diff --git a/cassandra_test.go b/cassandra_test.go index 3b0c61053..d8d37baac 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -242,7 +242,7 @@ func TestObserve(t *testing.T) { t.Fatal("select: unexpected observed stmt", observedStmt) } - // reports errors when the query is poorly formed + // reports internal_errors when the query is poorly formed resetObserved() value = 0 if err := session.Query(`SELECT id FROM unknown_table WHERE id = ?`, 42).Observer(observer).Scan(&value); err == nil { @@ -1451,8 +1451,8 @@ func injectInvalidPreparedStatement(t *testing.T, session *Session, table string flight.preparedStatment = &preparedStatment{ id: []byte{'f', 'o', 'o', 'b', 'a', 'r'}, - request: preparedMetadata{ - resultMetadata: resultMetadata{ + request: protocol.PreparedMetadata{ + protocol.ResultMetadata: protocol.ResultMetadata{ colCount: 1, actualColCount: 1, columns: []ColumnInfo{ diff --git a/compressor.go b/compressor.go index f3d451a9f..43b32472c 100644 --- a/compressor.go +++ b/compressor.go @@ -25,28 +25,9 @@ package gocql import ( - "github.com/golang/snappy" + "github.com/gocql/gocql/internal/compressor" ) -type Compressor interface { - Name() string - Encode(data []byte) ([]byte, error) - Decode(data []byte) ([]byte, error) -} +type Compressor = compressor.Compressor -// SnappyCompressor implements the Compressor interface and can be used to -// compress incoming and outgoing frames. The snappy compression algorithm -// aims for very high speeds and reasonable compression. -type SnappyCompressor struct{} - -func (s SnappyCompressor) Name() string { - return "snappy" -} - -func (s SnappyCompressor) Encode(data []byte) ([]byte, error) { - return snappy.Encode(nil, data), nil -} - -func (s SnappyCompressor) Decode(data []byte) ([]byte, error) { - return snappy.Decode(nil, data) -} +type SnappyCompressor = compressor.SnappyCompressor diff --git a/conn.go b/conn.go index ae02bd71c..60ed7319b 100644 --- a/conn.go +++ b/conn.go @@ -30,6 +30,8 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/gocql/gocql/internal/internal_errors" + "github.com/gocql/gocql/internal/protocol" "io" "io/ioutil" "net" @@ -188,7 +190,7 @@ type Conn struct { frameObserver FrameHeaderObserver streamObserver StreamObserver - headerBuf [maxFrameHeaderSize]byte + headerBuf [protocol.MaxFrameHeaderSize]byte streams *streams.IDGenerator mu sync.Mutex @@ -418,21 +420,21 @@ func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (fra return nil, err } - return framer.parseFrame() + return framer.ParseFrame() } func (s *startupCoordinator) options(ctx context.Context) error { - frame, err := s.write(ctx, &writeOptionsFrame{}) + frame, err := s.write(ctx, &protocol.WriteOptionsFrame{}) if err != nil { return err } - supported, ok := frame.(*supportedFrame) + supported, ok := frame.(*protocol.SupportedFrame) if !ok { - return NewErrProtocol("Unknown type of response to startup frame: %T", frame) + return protocol.NewErrProtocol("Unknown type of response to startup frame: %T", frame) } - return s.startup(ctx, supported.supported) + return s.startup(ctx, supported.Supported) } func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { @@ -457,7 +459,7 @@ func (s *startupCoordinator) startup(ctx context.Context, supported map[string][ } } - frame, err := s.write(ctx, &writeStartupFrame{opts: m}) + frame, err := s.write(ctx, &protocol.WriteStartupFrame{Opts: m}) if err != nil { return err } @@ -465,26 +467,26 @@ func (s *startupCoordinator) startup(ctx context.Context, supported map[string][ switch v := frame.(type) { case error: return v - case *readyFrame: + case *protocol.ReadyFrame: return nil - case *authenticateFrame: + case *protocol.AuthenticateFrame: return s.authenticateHandshake(ctx, v) default: - return NewErrProtocol("Unknown type of response to startup frame: %s", v) + return protocol.NewErrProtocol("Unknown type of response to startup frame: %s", v) } } -func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error { +func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *protocol.AuthenticateFrame) error { if s.conn.auth == nil { - return fmt.Errorf("authentication required (using %q)", authFrame.class) + return fmt.Errorf("authentication required (using %q)", authFrame.Class) } - resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) + resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.Class)) if err != nil { return err } - req := &writeAuthResponseFrame{data: resp} + req := &protocol.WriteAuthResponseFrame{Data: resp} for { frame, err := s.write(ctx, req) if err != nil { @@ -494,19 +496,19 @@ func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFram switch v := frame.(type) { case error: return v - case *authSuccessFrame: + case *protocol.AuthSuccessFrame: if challenger != nil { - return challenger.Success(v.data) + return challenger.Success(v.Data) } return nil - case *authChallengeFrame: - resp, challenger, err = challenger.Challenge(v.data) + case *protocol.AuthChallengeFrame: + resp, challenger, err = challenger.Challenge(v.Data) if err != nil { return err } - req = &writeAuthResponseFrame{ - data: resp, + req = &protocol.WriteAuthResponseFrame{ + Data: resp, } default: return fmt.Errorf("unknown frame response during authentication: %v", v) @@ -585,8 +587,8 @@ func (c *Conn) serve(ctx context.Context) { c.closeWithError(err) } -func (c *Conn) discardFrame(head frameHeader) error { - _, err := io.CopyN(ioutil.Discard, c, int64(head.length)) +func (c *Conn) discardFrame(head protocol.FrameHeader) error { + _, err := io.CopyN(ioutil.Discard, c, int64(head.Length)) if err != nil { return err } @@ -601,7 +603,7 @@ func (p *protocolError) Error() string { if err, ok := p.frame.(error); ok { return err.Error() } - return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame) + return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().Stream, p.frame) } func (c *Conn) heartBeat(ctx context.Context) { @@ -625,13 +627,13 @@ func (c *Conn) heartBeat(ctx context.Context) { case <-timer.C: } - framer, err := c.exec(context.Background(), &writeOptionsFrame{}, nil) + framer, err := c.exec(context.Background(), &protocol.WriteOptionsFrame{}, nil) if err != nil { failures++ continue } - resp, err := framer.parseFrame() + resp, err := framer.ParseFrame() if err != nil { // invalid frame failures++ @@ -639,7 +641,7 @@ func (c *Conn) heartBeat(ctx context.Context) { } switch resp.(type) { - case *supportedFrame: + case *protocol.SupportedFrame: // Everything ok sleepTime = 5 * time.Second failures = 0 @@ -662,7 +664,7 @@ func (c *Conn) recv(ctx context.Context) error { headStartTime := time.Now() // were just reading headers over and over and copy bodies - head, err := readHeader(c.r, c.headerBuf[:]) + head, err := protocol.ReadHeader(c.r, c.headerBuf[:]) headEndTime := time.Now() if err != nil { return err @@ -670,36 +672,36 @@ func (c *Conn) recv(ctx context.Context) error { if c.frameObserver != nil { c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{ - Version: protoVersion(head.version), - Flags: head.flags, - Stream: int16(head.stream), - Opcode: frameOp(head.op), - Length: int32(head.length), + Version: protocol.ProtoVersion(head.Version), + Flags: head.Flags, + Stream: int16(head.Stream), + Opcode: protocol.FrameOp(head.Op), + Length: int32(head.Length), Start: headStartTime, End: headEndTime, Host: c.host, }) } - if head.stream > c.streams.NumStreams { - return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream) - } else if head.stream == -1 { + if head.Stream > c.streams.NumStreams { + return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.Stream) + } else if head.Stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently - framer := newFramer(c.compressor, c.version) - if err := framer.readFrame(c, &head); err != nil { + framer := protocol.NewFramer(c.compressor, c.version) + if err := framer.ReadFrame(c, &head); err != nil { return err } go c.session.handleEvent(framer) return nil - } else if head.stream <= 0 { + } else if head.Stream <= 0 { // reserved stream that we dont use, probably due to a protocol error // or a bug in Cassandra, this should be an error, parse it and return. - framer := newFramer(c.compressor, c.version) - if err := framer.readFrame(c, &head); err != nil { + framer := protocol.NewFramer(c.compressor, c.version) + if err := framer.ReadFrame(c, &head); err != nil { return err } - frame, err := framer.parseFrame() + frame, err := framer.ParseFrame() if err != nil { return err } @@ -714,21 +716,21 @@ func (c *Conn) recv(ctx context.Context) error { c.mu.Unlock() return ErrConnectionClosed } - call, ok := c.calls[head.stream] - delete(c.calls, head.stream) + call, ok := c.calls[head.Stream] + delete(c.calls, head.Stream) c.mu.Unlock() if call == nil || !ok { c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) return c.discardFrame(head) - } else if head.stream != call.streamID { - panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) + } else if head.Stream != call.streamID { + panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.Stream)) } - framer := newFramer(c.compressor, c.version) + framer := protocol.NewFramer(c.compressor, c.version) - err = framer.readFrame(c, &head) + err = framer.ReadFrame(c, &head) if err != nil { - // only net errors should cause the connection to be closed. Though + // only net internal_errors should cause the connection to be closed. Though // cassandra returning corrupt frames will be returned here as well. if _, ok := err.(net.Error); ok { return err @@ -788,7 +790,7 @@ type callReq struct { type callResp struct { // framer is the response frame. // May be nil if err is not nil. - framer *framer + framer *protocol.Framer // err is error encountered, if any. err error } @@ -1018,7 +1020,7 @@ func (c *Conn) addCall(call *callReq) error { return nil } -func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) { +func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*protocol.Framer, error) { if ctxErr := ctx.Err(); ctxErr != nil { return nil, ctxErr } @@ -1030,7 +1032,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram } // resp is basically a waiting semaphore protecting the framer - framer := newFramer(c.compressor, c.version) + framer := protocol.NewFramer(c.compressor, c.version) call := &callReq{ timeout: make(chan struct{}), @@ -1051,7 +1053,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram // If we don't close(call.timeout) or read from call.resp, closeWithError can deadlock. if tracer != nil { - framer.trace() + framer.Trace() } if call.streamObserverContext != nil { @@ -1060,7 +1062,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram }) } - err := req.buildFrame(framer, stream) + err := req.BuildFrame(framer, stream) if err != nil { // closeWithError will block waiting for this stream to either receive a response // or for us to timeout. @@ -1078,7 +1080,7 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram return nil, err } - n, err := c.w.writeContext(ctx, framer.buf) + n, err := c.w.writeContext(ctx, framer.Buf) if err != nil { // closeWithError will block waiting for this stream to either receive a response // or for us to timeout, close the timeout chan here. Im not entirely sure @@ -1150,8 +1152,8 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram // requests on the stream to prevent nil pointer dereferences in recv(). defer c.releaseStream(call) - if v := resp.framer.header.version.version(); v != c.version { - return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version) + if v := resp.framer.Header.Version.Version(); v != c.version { + return nil, protocol.NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version) } return resp.framer, nil @@ -1216,8 +1218,8 @@ type StreamObserverContext interface { type preparedStatment struct { id []byte - request preparedMetadata - response resultMetadata + request protocol.PreparedMetadata + response protocol.ResultMetadata } type inflightPrepare struct { @@ -1241,11 +1243,11 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) go func() { defer close(flight.done) - prep := &writePrepareFrame{ - statement: stmt, + prep := &protocol.WritePrepareFrame{ + Statement: stmt, } if c.version > protoVersion4 { - prep.keyspace = c.currentKeyspace + prep.Keyspace = c.currentKeyspace } // we won the race to do the load, if our context is canceled we shouldnt @@ -1258,7 +1260,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) return } - frame, err := framer.parseFrame() + frame, err := framer.ParseFrame() if err != nil { flight.err = err c.session.stmtsLRU.remove(stmtCacheKey) @@ -1267,25 +1269,25 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated // everytime we need to parse a frame. - if len(framer.traceID) > 0 && tracer != nil { - tracer.Trace(framer.traceID) + if len(framer.TraceID) > 0 && tracer != nil { + tracer.Trace(framer.TraceID) } switch x := frame.(type) { - case *resultPreparedFrame: + case *protocol.ResultPreparedFrame: flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. - id: copyBytes(x.preparedID), + id: protocol.CopyBytes(x.PreparedID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. - request: x.reqMeta, - response: x.respMeta, + request: x.ReqMeta, + response: x.RespMeta, } case error: flight.err = x default: - flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x) + flight.err = protocol.NewErrProtocol("Unknown type in response to prepare frame: %s", x) } if flight.err != nil { @@ -1302,44 +1304,44 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) } } -func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { - if named, ok := value.(*namedValue); ok { - dst.name = named.name - value = named.value +func marshalQueryValue(typ TypeInfo, value interface{}, dst *protocol.QueryValues) error { + if named, ok := value.(*protocol.NamedValue); ok { + dst.Name = named.Name + value = named.Value } - if _, ok := value.(unsetColumn); !ok { + if _, ok := value.(protocol.UnsetColumn); !ok { val, err := Marshal(typ, value) if err != nil { return err } - dst.value = val + dst.Value = val } else { - dst.isUnset = true + dst.IsUnset = true } return nil } func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { - params := queryParams{ - consistency: qry.cons, + params := protocol.QueryParams{ + Consistency: qry.cons, } // frame checks that it is not 0 - params.serialConsistency = qry.serialCons - params.defaultTimestamp = qry.defaultTimestamp - params.defaultTimestampValue = qry.defaultTimestampValue + params.SerialConsistency = qry.serialCons + params.DefaultTimestamp = qry.defaultTimestamp + params.DefaultTimestampValue = qry.defaultTimestampValue if len(qry.pageState) > 0 { - params.pagingState = qry.pageState + params.PagingState = qry.pageState } if qry.pageSize > 0 { - params.pageSize = qry.pageSize + params.PageSize = qry.pageSize } if c.version > protoVersion4 { - params.keyspace = c.currentKeyspace + params.Keyspace = c.currentKeyspace } var ( @@ -1359,9 +1361,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if qry.binding != nil { values, err = qry.binding(&QueryInfo{ Id: info.id, - Args: info.request.columns, - Rval: info.response.columns, - PKeyColumns: info.request.pkeyColumns, + Args: info.request.Columns, + Rval: info.response.Columns, + PKeyColumns: info.request.PkeyColumns, }) if err != nil { @@ -1369,38 +1371,38 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { } } - if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))} + if len(values) != info.request.ActualColCount { + return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.ActualColCount, len(values))} } - params.values = make([]queryValues, len(values)) + params.Values = make([]protocol.QueryValues, len(values)) for i := 0; i < len(values); i++ { - v := ¶ms.values[i] + v := ¶ms.Values[i] value := values[i] - typ := info.request.columns[i].TypeInfo + typ := info.request.Columns[i].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { return &Iter{err: err} } } - params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) + params.SkipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) - frame = &writeExecuteFrame{ - preparedID: info.id, - params: params, - customPayload: qry.customPayload, + frame = &protocol.WriteExecuteFrame{ + PreparedID: info.id, + Params: params, + CustomPayload: qry.customPayload, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata qry.routingInfo.mu.Lock() - qry.routingInfo.keyspace = info.request.keyspace - qry.routingInfo.table = info.request.table + qry.routingInfo.keyspace = info.request.Keyspace + qry.routingInfo.table = info.request.Table qry.routingInfo.mu.Unlock() } else { - frame = &writeQueryFrame{ - statement: qry.stmt, - params: params, - customPayload: qry.customPayload, + frame = &protocol.WriteQueryFrame{ + Statement: qry.stmt, + Params: params, + CustomPayload: qry.customPayload, } } @@ -1409,45 +1411,45 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { return &Iter{err: err} } - resp, err := framer.parseFrame() + resp, err := framer.ParseFrame() if err != nil { return &Iter{err: err} } - if len(framer.traceID) > 0 && qry.trace != nil { - qry.trace.Trace(framer.traceID) + if len(framer.TraceID) > 0 && qry.trace != nil { + qry.trace.Trace(framer.TraceID) } switch x := resp.(type) { - case *resultVoidFrame: + case *protocol.ResultVoidFrame: return &Iter{framer: framer} - case *resultRowsFrame: + case *protocol.ResultRowsFrame: iter := &Iter{ - meta: x.meta, + meta: x.Meta, framer: framer, - numRows: x.numRows, + numRows: x.NumRows, } - if params.skipMeta { + if params.SkipMeta { if info != nil { iter.meta = info.response - iter.meta.pagingState = copyBytes(x.meta.pagingState) + iter.meta.PagingState = protocol.CopyBytes(x.Meta.PagingState) } else { return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} } } else { - iter.meta = x.meta + iter.meta = x.Meta } - if x.meta.morePages() && !qry.disableAutoPage { + if x.Meta.MorePages() && !qry.disableAutoPage { newQry := new(Query) *newQry = *qry - newQry.pageState = copyBytes(x.meta.pagingState) + newQry.pageState = protocol.CopyBytes(x.Meta.PagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ qry: newQry, - pos: int((1 - qry.prefetch) * float64(x.numRows)), + pos: int((1 - qry.prefetch) * float64(x.NumRows)), } if iter.next.pos < 1 { @@ -1456,9 +1458,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { } return iter - case *resultKeyspaceFrame: + case *protocol.ResultKeyspaceFrame: return &Iter{framer: framer} - case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: + case *protocol.SchemaChangeKeyspace, *protocol.SchemaChangeTable, *protocol.SchemaChangeFunction, *protocol.SchemaChangeAggregate, *protocol.SchemaChangeType: iter := &Iter{framer: framer} if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag @@ -1468,7 +1470,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // though. The impact of this returning an error would be that the cluster // is not consistent with regards to its schema. return iter - case *RequestErrUnprepared: + case *internal_errors.RequestErrUnprepared: stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) return c.executeQuery(ctx, qry) @@ -1476,7 +1478,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { return &Iter{err: x, framer: framer} default: return &Iter{ - err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), + err: protocol.NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), framer: framer, } } @@ -1504,25 +1506,25 @@ func (c *Conn) AvailableStreams() int { } func (c *Conn) UseKeyspace(keyspace string) error { - q := &writeQueryFrame{statement: `USE "` + keyspace + `"`} - q.params.consistency = c.session.cons + q := &protocol.WriteQueryFrame{Statement: `USE "` + keyspace + `"`} + q.Params.Consistency = c.session.cons framer, err := c.exec(c.ctx, q, nil) if err != nil { return err } - resp, err := framer.parseFrame() + resp, err := framer.ParseFrame() if err != nil { return err } switch x := resp.(type) { - case *resultKeyspaceFrame: + case *protocol.ResultKeyspaceFrame: case error: return x default: - return NewErrProtocol("unknown frame in response to USE: %v", x) + return protocol.NewErrProtocol("unknown frame in response to USE: %v", x) } c.currentKeyspace = keyspace @@ -1536,21 +1538,21 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { } n := len(batch.Entries) - req := &writeBatchFrame{ - typ: batch.Type, - statements: make([]batchStatment, n), - consistency: batch.Cons, - serialConsistency: batch.serialCons, - defaultTimestamp: batch.defaultTimestamp, - defaultTimestampValue: batch.defaultTimestampValue, - customPayload: batch.CustomPayload, + req := &protocol.WriteBatchFrame{ + Typ: batch.Type, + Statements: make([]protocol.BatchStatment, n), + Consistency: batch.Cons, + SerialConsistency: batch.serialCons, + DefaultTimestamp: batch.defaultTimestamp, + DefaultTimestampValue: batch.defaultTimestampValue, + CustomPayload: batch.CustomPayload, } stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { entry := &batch.Entries[i] - b := &req.statements[i] + b := &req.Statements[i] if len(entry.Args) > 0 || entry.binding != nil { info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace) @@ -1564,34 +1566,34 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { } else { values, err = entry.binding(&QueryInfo{ Id: info.id, - Args: info.request.columns, - Rval: info.response.columns, - PKeyColumns: info.request.pkeyColumns, + Args: info.request.Columns, + Rval: info.response.Columns, + PKeyColumns: info.request.PkeyColumns, }) if err != nil { return &Iter{err: err} } } - if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))} + if len(values) != info.request.ActualColCount { + return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.ActualColCount, len(values))} } - b.preparedID = info.id + b.PreparedID = info.id stmts[string(info.id)] = entry.Stmt - b.values = make([]queryValues, info.request.actualColCount) + b.Values = make([]protocol.QueryValues, info.request.ActualColCount) - for j := 0; j < info.request.actualColCount; j++ { - v := &b.values[j] + for j := 0; j < info.request.ActualColCount; j++ { + v := &b.Values[j] value := values[j] - typ := info.request.columns[j].TypeInfo + typ := info.request.Columns[j].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { return &Iter{err: err} } } } else { - b.statement = entry.Stmt + b.Statement = entry.Stmt } } @@ -1600,37 +1602,37 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { return &Iter{err: err} } - resp, err := framer.parseFrame() + resp, err := framer.ParseFrame() if err != nil { return &Iter{err: err, framer: framer} } - if len(framer.traceID) > 0 && batch.trace != nil { - batch.trace.Trace(framer.traceID) + if len(framer.TraceID) > 0 && batch.trace != nil { + batch.trace.Trace(framer.TraceID) } switch x := resp.(type) { - case *resultVoidFrame: + case *protocol.ResultVoidFrame: return &Iter{} - case *RequestErrUnprepared: + case *internal_errors.RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } return c.executeBatch(ctx, batch) - case *resultRowsFrame: + case *protocol.ResultRowsFrame: iter := &Iter{ - meta: x.meta, + meta: x.Meta, framer: framer, - numRows: x.numRows, + numRows: x.NumRows, } return iter case error: return &Iter{err: x, framer: framer} default: - return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} + return &Iter{err: protocol.NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} } } @@ -1659,7 +1661,7 @@ func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter err := iter.checkErrAndNotFound() if err != nil { - if errFrame, ok := err.(errorFrame); ok && errFrame.code == ErrCodeInvalid { // system.peers_v2 not found, try system.peers + if errFrame, ok := err.(internal_errors.ErrorFrame); ok && errFrame.Cod == internal_errors.ErrCodeInvalid { // system.peers_v2 not found, try system.peers c.mu.Lock() c.isSchemaV2 = false c.mu.Unlock() diff --git a/conn_test.go b/conn_test.go index 8706683ff..024afe8db 100644 --- a/conn_test.go +++ b/conn_test.go @@ -698,7 +698,7 @@ func TestStream0(t *testing.T) { const expErr = "gocql: received unexpected frame on stream 0" var buf bytes.Buffer - f := newFramer(nil, protoVersion4) + f := protocol.NewFramer(nil, protoVersion4) f.writeHeader(0, opResult, 0) f.writeInt(resultKindVoid) f.buf[0] |= 0x80 @@ -1193,7 +1193,7 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { srv.errorLocked("process frame with a nil header") return } - respFrame := newFramer(nil, reqFrame.proto) + respFrame := protocol.NewFramer(nil, reqFrame.proto) switch head.op { case opStartup: @@ -1283,11 +1283,11 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { buf := make([]byte, srv.headerSize) - head, err := readHeader(conn, buf) + head, err := protocol.ReadHeader(conn, buf) if err != nil { return nil, err } - framer := newFramer(nil, srv.protocol) + framer := protocol.NewFramer(nil, srv.protocol) err = framer.readFrame(conn, &head) if err != nil { diff --git a/connectionpool.go b/connectionpool.go index 2ccd3c8a7..04160ecad 100644 --- a/connectionpool.go +++ b/connectionpool.go @@ -614,7 +614,7 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) { return } - // TODO: track the number of errors per host and detect when a host is dead, + // TODO: track the number of internal_errors per host and detect when a host is dead, // then also have something which can detect when a host comes back. pool.mu.Lock() defer pool.mu.Unlock() diff --git a/control.go b/control.go index b30b44ea3..61d98290c 100644 --- a/control.go +++ b/control.go @@ -29,6 +29,7 @@ import ( crand "crypto/rand" "errors" "fmt" + "github.com/gocql/gocql/internal/protocol" "math/rand" "net" "os" @@ -50,7 +51,7 @@ func init() { panic(fmt.Sprintf("unable to seed random number generator: %v", err)) } - randr = rand.New(rand.NewSource(int64(readInt(b)))) + randr = rand.New(rand.NewSource(int64(protocol.ReadInt(b)))) } const ( @@ -103,13 +104,13 @@ func (c *controlConn) heartBeat() { case <-timer.C: } - resp, err := c.writeFrame(&writeOptionsFrame{}) + resp, err := c.writeFrame(&protocol.WriteOptionsFrame{}) if err != nil { goto reconn } switch resp.(type) { - case *supportedFrame: + case *protocol.SupportedFrame: // Everything ok sleepTime = 5 * time.Second continue @@ -199,7 +200,7 @@ func parseProtocolFromError(err error) int { matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1) if len(matches) != 1 || len(matches[0]) != 2 { if verr, ok := err.(*protocolError); ok { - return int(verr.frame.Header().version.version()) + return int(verr.frame.Header().Version.Version()) } return 0 } @@ -345,17 +346,17 @@ func (c *controlConn) registerEvents(conn *Conn) error { } framer, err := conn.exec(context.Background(), - &writeRegisterFrame{ - events: events, + &protocol.WriteRegisterFrame{ + Events: events, }, nil) if err != nil { return err } - frame, err := framer.parseFrame() + frame, err := framer.ParseFrame() if err != nil { return err - } else if _, ok := frame.(*readyFrame); !ok { + } else if _, ok := frame.(*protocol.ReadyFrame); !ok { return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame) } @@ -459,7 +460,7 @@ func (c *controlConn) getConn() *connHost { return c.conn.Load().(*connHost) } -func (c *controlConn) writeFrame(w frameBuilder) (frame, error) { +func (c *controlConn) writeFrame(w protocol.FrameBuilder) (frame, error) { ch := c.getConn() if ch == nil { return nil, errNoControl @@ -470,7 +471,7 @@ func (c *controlConn) writeFrame(w frameBuilder) (frame, error) { return nil, err } - return framer.parseFrame() + return framer.ParseFrame() } func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter { diff --git a/control_test.go b/control_test.go index 9713718e6..eb3592e86 100644 --- a/control_test.go +++ b/control_test.go @@ -62,7 +62,7 @@ func TestParseProtocol(t *testing.T) { }{ { err: &protocolError{ - frame: errorFrame{ + frame: internal_errors.ErrorFrame{ code: 0x10, message: "Invalid or unsupported protocol version (5); the lowest supported version is 3 and the greatest is 4", }, @@ -71,7 +71,7 @@ func TestParseProtocol(t *testing.T) { }, { err: &protocolError{ - frame: errorFrame{ + frame: internal_errors.ErrorFrame{ frameHeader: frameHeader{ version: 0x83, }, diff --git a/cqltypes.go b/cqltypes.go index ce2e1cee7..ca6a9199c 100644 --- a/cqltypes.go +++ b/cqltypes.go @@ -24,8 +24,6 @@ package gocql -type Duration struct { - Months int32 - Days int32 - Nanoseconds int64 -} +import "github.com/gocql/gocql/internal/protocol" + +type Duration = protocol.Duration diff --git a/doc.go b/doc.go index 236b55e2f..e40cf692b 100644 --- a/doc.go +++ b/doc.go @@ -350,7 +350,7 @@ // multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative // execution. // -// Idempotent queries are retried in case of errors based on the configured RetryPolicy. +// Idempotent queries are retried in case of internal_errors based on the configured RetryPolicy. // // Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can // cause the driver to retry on a different node if the query is taking longer than a specified delay even before the diff --git a/errors.go b/errors.go index d64c46208..4fd140e4a 100644 --- a/errors.go +++ b/errors.go @@ -24,199 +24,6 @@ package gocql -import "fmt" +func lol() { -// See CQL Binary Protocol v5, section 8 for more details. -// https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec -const ( - // ErrCodeServer indicates unexpected error on server-side. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247 - ErrCodeServer = 0x0000 - // ErrCodeProtocol indicates a protocol violation by some client message. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250 - ErrCodeProtocol = 0x000A - // ErrCodeCredentials indicates missing required authentication. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254 - ErrCodeCredentials = 0x0100 - // ErrCodeUnavailable indicates unavailable error. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265 - ErrCodeUnavailable = 0x1000 - // ErrCodeOverloaded returned in case of request on overloaded node coordinator. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267 - ErrCodeOverloaded = 0x1001 - // ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269 - ErrCodeBootstrapping = 0x1002 - // ErrCodeTruncate indicates truncation exception. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270 - ErrCodeTruncate = 0x1003 - // ErrCodeWriteTimeout returned in case of timeout during the request write. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304 - ErrCodeWriteTimeout = 0x1100 - // ErrCodeReadTimeout returned in case of timeout during the request read. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321 - ErrCodeReadTimeout = 0x1200 - // ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340 - ErrCodeReadFailure = 0x1300 - // ErrCodeFunctionFailure indicates an error in user-defined function. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347 - ErrCodeFunctionFailure = 0x1400 - // ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385 - ErrCodeWriteFailure = 0x1500 - // ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386 - ErrCodeCDCWriteFailure = 0x1600 - // ErrCodeCASWriteUnknown indicates only partially completed CAS operation. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 - ErrCodeCASWriteUnknown = 0x1700 - // ErrCodeSyntax indicates the syntax error in the query. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399 - ErrCodeSyntax = 0x2000 - // ErrCodeUnauthorized indicates access rights violation by user on performed operation. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401 - ErrCodeUnauthorized = 0x2100 - // ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402 - ErrCodeInvalid = 0x2200 - // ErrCodeConfig indicates the configuration error. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403 - ErrCodeConfig = 0x2300 - // ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413 - ErrCodeAlreadyExists = 0x2400 - // ErrCodeUnprepared returned from the host for prepared statement which is unknown. - // - // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417 - ErrCodeUnprepared = 0x2500 -) - -type RequestError interface { - Code() int - Message() string - Error() string -} - -type errorFrame struct { - frameHeader - - code int - message string -} - -func (e errorFrame) Code() int { - return e.code -} - -func (e errorFrame) Message() string { - return e.message -} - -func (e errorFrame) Error() string { - return e.Message() -} - -func (e errorFrame) String() string { - return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message) -} - -type RequestErrUnavailable struct { - errorFrame - Consistency Consistency - Required int - Alive int -} - -func (e *RequestErrUnavailable) String() string { - return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive) -} - -type ErrorMap map[string]uint16 - -type RequestErrWriteTimeout struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - WriteType string -} - -type RequestErrWriteFailure struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - NumFailures int - WriteType string - ErrorMap ErrorMap -} - -type RequestErrCDCWriteFailure struct { - errorFrame -} - -type RequestErrReadTimeout struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - DataPresent byte -} - -type RequestErrAlreadyExists struct { - errorFrame - Keyspace string - Table string -} - -type RequestErrUnprepared struct { - errorFrame - StatementId []byte -} - -type RequestErrReadFailure struct { - errorFrame - Consistency Consistency - Received int - BlockFor int - NumFailures int - DataPresent bool - ErrorMap ErrorMap -} - -type RequestErrFunctionFailure struct { - errorFrame - Keyspace string - Function string - ArgTypes []string -} - -// RequestErrCASWriteUnknown is distinct error for ErrCodeCasWriteUnknown. -// -// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 -type RequestErrCASWriteUnknown struct { - errorFrame - Consistency Consistency - Received int - BlockFor int } diff --git a/events.go b/events.go index 93b001acc..7301be5eb 100644 --- a/events.go +++ b/events.go @@ -25,6 +25,7 @@ package gocql import ( + "github.com/gocql/gocql/internal/protocol" "net" "sync" "time" @@ -106,8 +107,8 @@ func (e *eventDebouncer) debounce(frame frame) { e.mu.Unlock() } -func (s *Session) handleEvent(framer *framer) { - frame, err := framer.parseFrame() +func (s *Session) handleEvent(framer *protocol.Framer) { + frame, err := framer.ParseFrame() if err != nil { s.logger.Printf("gocql: unable to parse event frame: %v\n", err) return @@ -118,11 +119,11 @@ func (s *Session) handleEvent(framer *framer) { } switch f := frame.(type) { - case *schemaChangeKeyspace, *schemaChangeFunction, - *schemaChangeTable, *schemaChangeAggregate, *schemaChangeType: + case *protocol.SchemaChangeKeyspace, *protocol.SchemaChangeFunction, + *protocol.SchemaChangeTable, *protocol.SchemaChangeAggregate, *protocol.SchemaChangeType: s.schemaEvents.debounce(frame) - case *topologyChangeEventFrame, *statusChangeEventFrame: + case *protocol.TopologyChangeEventFrame, *protocol.StatusChangeEventFrame: s.nodeEvents.debounce(frame) default: s.logger.Printf("gocql: invalid event frame (%T): %v\n", f, f) @@ -133,17 +134,17 @@ func (s *Session) handleSchemaEvent(frames []frame) { // TODO: debounce events for _, frame := range frames { switch f := frame.(type) { - case *schemaChangeKeyspace: - s.schemaDescriber.clearSchema(f.keyspace) - s.handleKeyspaceChange(f.keyspace, f.change) - case *schemaChangeTable: - s.schemaDescriber.clearSchema(f.keyspace) - case *schemaChangeAggregate: - s.schemaDescriber.clearSchema(f.keyspace) - case *schemaChangeFunction: - s.schemaDescriber.clearSchema(f.keyspace) - case *schemaChangeType: - s.schemaDescriber.clearSchema(f.keyspace) + case *protocol.SchemaChangeKeyspace: + s.schemaDescriber.clearSchema(f.Keyspace) + s.handleKeyspaceChange(f.Keyspace, f.Change) + case *protocol.SchemaChangeTable: + s.schemaDescriber.clearSchema(f.Keyspace) + case *protocol.SchemaChangeAggregate: + s.schemaDescriber.clearSchema(f.Keyspace) + case *protocol.SchemaChangeFunction: + s.schemaDescriber.clearSchema(f.Keyspace) + case *protocol.SchemaChangeType: + s.schemaDescriber.clearSchema(f.Keyspace) } } } @@ -176,15 +177,15 @@ func (s *Session) handleNodeEvent(frames []frame) { for _, frame := range frames { switch f := frame.(type) { - case *topologyChangeEventFrame: + case *protocol.TopologyChangeEventFrame: topologyEventReceived = true - case *statusChangeEventFrame: - event, ok := sEvents[f.host.String()] + case *protocol.StatusChangeEventFrame: + event, ok := sEvents[f.Host.String()] if !ok { - event = &nodeEvent{change: f.change, host: f.host, port: f.port} - sEvents[f.host.String()] = event + event = &nodeEvent{change: f.Change, host: f.Host, port: f.Port} + sEvents[f.Host.String()] = event } - event.change = f.change + event.change = f.Change } } diff --git a/events_test.go b/events_test.go index 537c51885..4985d9594 100644 --- a/events_test.go +++ b/events_test.go @@ -43,7 +43,7 @@ func TestEventDebounce(t *testing.T) { defer debouncer.stop() for i := 0; i < eventCount; i++ { - debouncer.debounce(&statusChangeEventFrame{ + debouncer.debounce(&protocol.StatusChangeEventFrame{ change: "UP", host: net.IPv4(127, 0, 0, 1), port: 9042, diff --git a/frame.go b/frame.go index d374ae574..d6a929479 100644 --- a/frame.go +++ b/frame.go @@ -28,16 +28,11 @@ import ( "context" "errors" "fmt" - "io" - "io/ioutil" - "net" - "runtime" - "strings" + "github.com/gocql/gocql/internal/protocol" "time" ) -type unsetColumn struct{} - +// type unsetColumn= protocol.UnsetColumn // UnsetValue represents a value used in a query binding that will be ignored by Cassandra. // // By setting a field to the unset value Cassandra will ignore the write completely. @@ -45,320 +40,56 @@ type unsetColumn struct{} // want to update some fields, where before you needed to make another prepared statement. // // UnsetValue is only available when using the version 4 of the protocol. -var UnsetValue = unsetColumn{} - -type namedValue struct { - name string - value interface{} -} - -// NamedValue produce a value which will bind to the named parameter in a query -func NamedValue(name string, value interface{}) interface{} { - return &namedValue{ - name: name, - value: value, - } -} - -const ( - protoDirectionMask = 0x80 - protoVersionMask = 0x7F - protoVersion1 = 0x01 - protoVersion2 = 0x02 - protoVersion3 = 0x03 - protoVersion4 = 0x04 - protoVersion5 = 0x05 - - maxFrameSize = 256 * 1024 * 1024 -) - -type protoVersion byte - -func (p protoVersion) request() bool { - return p&protoDirectionMask == 0x00 -} - -func (p protoVersion) response() bool { - return p&protoDirectionMask == 0x80 -} - -func (p protoVersion) version() byte { - return byte(p) & protoVersionMask -} - -func (p protoVersion) String() string { - dir := "REQ" - if p.response() { - dir = "RESP" - } - - return fmt.Sprintf("[version=%d direction=%s]", p.version(), dir) -} - -type frameOp byte - -const ( - // header ops - opError frameOp = 0x00 - opStartup frameOp = 0x01 - opReady frameOp = 0x02 - opAuthenticate frameOp = 0x03 - opOptions frameOp = 0x05 - opSupported frameOp = 0x06 - opQuery frameOp = 0x07 - opResult frameOp = 0x08 - opPrepare frameOp = 0x09 - opExecute frameOp = 0x0A - opRegister frameOp = 0x0B - opEvent frameOp = 0x0C - opBatch frameOp = 0x0D - opAuthChallenge frameOp = 0x0E - opAuthResponse frameOp = 0x0F - opAuthSuccess frameOp = 0x10 -) - -func (f frameOp) String() string { - switch f { - case opError: - return "ERROR" - case opStartup: - return "STARTUP" - case opReady: - return "READY" - case opAuthenticate: - return "AUTHENTICATE" - case opOptions: - return "OPTIONS" - case opSupported: - return "SUPPORTED" - case opQuery: - return "QUERY" - case opResult: - return "RESULT" - case opPrepare: - return "PREPARE" - case opExecute: - return "EXECUTE" - case opRegister: - return "REGISTER" - case opEvent: - return "EVENT" - case opBatch: - return "BATCH" - case opAuthChallenge: - return "AUTH_CHALLENGE" - case opAuthResponse: - return "AUTH_RESPONSE" - case opAuthSuccess: - return "AUTH_SUCCESS" - default: - return fmt.Sprintf("UNKNOWN_OP_%d", f) - } -} - -const ( - // result kind - resultKindVoid = 1 - resultKindRows = 2 - resultKindKeyspace = 3 - resultKindPrepared = 4 - resultKindSchemaChanged = 5 +var UnsetValue = protocol.UnsetColumn{} - // rows flags - flagGlobalTableSpec int = 0x01 - flagHasMorePages int = 0x02 - flagNoMetaData int = 0x04 +type Consistency = protocol.Consistency - // query flags - flagValues byte = 0x01 - flagSkipMetaData byte = 0x02 - flagPageSize byte = 0x04 - flagWithPagingState byte = 0x08 - flagWithSerialConsistency byte = 0x10 - flagDefaultTimestamp byte = 0x20 - flagWithNameValues byte = 0x40 - flagWithKeyspace byte = 0x80 +type SerialConsistency = protocol.SerialConsistency - // prepare flags - flagWithPreparedKeyspace uint32 = 0x01 - - // header flags - flagCompress byte = 0x01 - flagTracing byte = 0x02 - flagCustomPayload byte = 0x04 - flagWarning byte = 0x08 - flagBetaProtocol byte = 0x10 -) - -type Consistency uint16 +type frame = protocol.Frame +type frameBuilder = protocol.FrameBuilder const ( - Any Consistency = 0x00 - One Consistency = 0x01 - Two Consistency = 0x02 - Three Consistency = 0x03 - Quorum Consistency = 0x04 - All Consistency = 0x05 - LocalQuorum Consistency = 0x06 - EachQuorum Consistency = 0x07 - LocalOne Consistency = 0x0A + Any = protocol.Any + One = protocol.One + Two = protocol.Two + Three = protocol.Three + Quorum = protocol.Quorum + All = protocol.All + LocalQuorum = protocol.LocalQuorum + EachQuorum = protocol.EachQuorum + LocalOne = protocol.LocalOne ) -func (c Consistency) String() string { - switch c { - case Any: - return "ANY" - case One: - return "ONE" - case Two: - return "TWO" - case Three: - return "THREE" - case Quorum: - return "QUORUM" - case All: - return "ALL" - case LocalQuorum: - return "LOCAL_QUORUM" - case EachQuorum: - return "EACH_QUORUM" - case LocalOne: - return "LOCAL_ONE" - default: - return fmt.Sprintf("UNKNOWN_CONS_0x%x", uint16(c)) - } -} - -func (c Consistency) MarshalText() (text []byte, err error) { - return []byte(c.String()), nil -} - -func (c *Consistency) UnmarshalText(text []byte) error { - switch string(text) { - case "ANY": - *c = Any - case "ONE": - *c = One - case "TWO": - *c = Two - case "THREE": - *c = Three - case "QUORUM": - *c = Quorum - case "ALL": - *c = All - case "LOCAL_QUORUM": - *c = LocalQuorum - case "EACH_QUORUM": - *c = EachQuorum - case "LOCAL_ONE": - *c = LocalOne - default: - return fmt.Errorf("invalid consistency %q", string(text)) - } - - return nil -} - -func ParseConsistency(s string) Consistency { - var c Consistency - if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { - panic(err) - } - return c -} - -// ParseConsistencyWrapper wraps gocql.ParseConsistency to provide an err -// return instead of a panic -func ParseConsistencyWrapper(s string) (consistency Consistency, err error) { - err = consistency.UnmarshalText([]byte(strings.ToUpper(s))) - return -} - -// MustParseConsistency is the same as ParseConsistency except it returns -// an error (never). It is kept here since breaking changes are not good. -// DEPRECATED: use ParseConsistency if you want a panic on parse error. -func MustParseConsistency(s string) (Consistency, error) { - c, err := ParseConsistencyWrapper(s) - if err != nil { - panic(err) - } - return c, nil -} - -type SerialConsistency uint16 - const ( - Serial SerialConsistency = 0x08 - LocalSerial SerialConsistency = 0x09 + protoDirectionMask = protocol.ProtoDirectionMask + protoVersionMask = protocol.ProtoVersionMask + protoVersion1 = protocol.ProtoVersion1 + protoVersion2 = protocol.ProtoVersion2 + protoVersion3 = protocol.ProtoVersion3 + protoVersion4 = protocol.ProtoVersion4 + protoVersion5 = protocol.ProtoVersion5 + + maxFrameSize = protocol.MaxFrameSize ) -func (s SerialConsistency) String() string { - switch s { - case Serial: - return "SERIAL" - case LocalSerial: - return "LOCAL_SERIAL" - default: - return fmt.Sprintf("UNKNOWN_SERIAL_CONS_0x%x", uint16(s)) - } -} - -func (s SerialConsistency) MarshalText() (text []byte, err error) { - return []byte(s.String()), nil -} - -func (s *SerialConsistency) UnmarshalText(text []byte) error { - switch string(text) { - case "SERIAL": - *s = Serial - case "LOCAL_SERIAL": - *s = LocalSerial - default: - return fmt.Errorf("invalid consistency %q", string(text)) +// NamedValue produce a value which will bind to the named parameter in a query +func NamedValue(name string, value interface{}) interface{} { + return &protocol.NamedValue{ + Name: name, + Value: value, } - - return nil } -const ( - apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." -) - var ( ErrFrameTooBig = errors.New("frame length is bigger than the maximum allowed") ) -const maxFrameHeaderSize = 9 - -func readInt(p []byte) int32 { - return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) -} - -type frameHeader struct { - version protoVersion - flags byte - stream int - op frameOp - length int - warnings []string -} - -func (f frameHeader) String() string { - return fmt.Sprintf("[header version=%s flags=0x%x stream=%d op=%s length=%d]", f.version, f.flags, f.stream, f.op, f.length) -} - -func (f frameHeader) Header() frameHeader { - return f -} - -const defaultBufSize = 128 - type ObservedFrameHeader struct { - Version protoVersion + Version protocol.ProtoVersion Flags byte Stream int16 - Opcode frameOp + Opcode protocol.FrameOp Length int32 // StartHeader is the time we started reading the frame header off the network connection. @@ -381,1692 +112,3 @@ type FrameHeaderObserver interface { // ObserveFrameHeader gets called on every received frame header. ObserveFrameHeader(context.Context, ObservedFrameHeader) } - -// a framer is responsible for reading, writing and parsing frames on a single stream -type framer struct { - proto byte - // flags are for outgoing flags, enabling compression and tracing etc - flags byte - compres Compressor - headSize int - // if this frame was read then the header will be here - header *frameHeader - - // if tracing flag is set this is not nil - traceID []byte - - // holds a ref to the whole byte slice for buf so that it can be reset to - // 0 after a read. - readBuffer []byte - - buf []byte - - customPayload map[string][]byte -} - -func newFramer(compressor Compressor, version byte) *framer { - buf := make([]byte, defaultBufSize) - f := &framer{ - buf: buf[:0], - readBuffer: buf, - } - var flags byte - if compressor != nil { - flags |= flagCompress - } - if version == protoVersion5 { - flags |= flagBetaProtocol - } - - version &= protoVersionMask - - headSize := 8 - if version > protoVersion2 { - headSize = 9 - } - - f.compres = compressor - f.proto = version - f.flags = flags - f.headSize = headSize - - f.header = nil - f.traceID = nil - - return f -} - -type frame interface { - Header() frameHeader -} - -func readHeader(r io.Reader, p []byte) (head frameHeader, err error) { - _, err = io.ReadFull(r, p[:1]) - if err != nil { - return frameHeader{}, err - } - - version := p[0] & protoVersionMask - - if version < protoVersion1 || version > protoVersion5 { - return frameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) - } - - headSize := 9 - if version < protoVersion3 { - headSize = 8 - } - - _, err = io.ReadFull(r, p[1:headSize]) - if err != nil { - return frameHeader{}, err - } - - p = p[:headSize] - - head.version = protoVersion(p[0]) - head.flags = p[1] - - if version > protoVersion2 { - if len(p) != 9 { - return frameHeader{}, fmt.Errorf("not enough bytes to read header require 9 got: %d", len(p)) - } - - head.stream = int(int16(p[2])<<8 | int16(p[3])) - head.op = frameOp(p[4]) - head.length = int(readInt(p[5:])) - } else { - if len(p) != 8 { - return frameHeader{}, fmt.Errorf("not enough bytes to read header require 8 got: %d", len(p)) - } - - head.stream = int(int8(p[2])) - head.op = frameOp(p[3]) - head.length = int(readInt(p[4:])) - } - - return head, nil -} - -// explicitly enables tracing for the framers outgoing requests -func (f *framer) trace() { - f.flags |= flagTracing -} - -// explicitly enables the custom payload flag -func (f *framer) payload() { - f.flags |= flagCustomPayload -} - -// reads a frame form the wire into the framers buffer -func (f *framer) readFrame(r io.Reader, head *frameHeader) error { - if head.length < 0 { - return fmt.Errorf("frame body length can not be less than 0: %d", head.length) - } else if head.length > maxFrameSize { - // need to free up the connection to be used again - _, err := io.CopyN(ioutil.Discard, r, int64(head.length)) - if err != nil { - return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err) - } - return ErrFrameTooBig - } - - if cap(f.readBuffer) >= head.length { - f.buf = f.readBuffer[:head.length] - } else { - f.readBuffer = make([]byte, head.length) - f.buf = f.readBuffer - } - - // assume the underlying reader takes care of timeouts and retries - n, err := io.ReadFull(r, f.buf) - if err != nil { - return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err) - } - - if head.flags&flagCompress == flagCompress { - if f.compres == nil { - return NewErrProtocol("no compressor available with compressed frame body") - } - - f.buf, err = f.compres.Decode(f.buf) - if err != nil { - return err - } - } - - f.header = head - return nil -} - -func (f *framer) parseFrame() (frame frame, err error) { - defer func() { - if r := recover(); r != nil { - if _, ok := r.(runtime.Error); ok { - panic(r) - } - err = r.(error) - } - }() - - if f.header.version.request() { - return nil, NewErrProtocol("got a request frame from server: %v", f.header.version) - } - - if f.header.flags&flagTracing == flagTracing { - f.readTrace() - } - - if f.header.flags&flagWarning == flagWarning { - f.header.warnings = f.readStringList() - } - - if f.header.flags&flagCustomPayload == flagCustomPayload { - f.customPayload = f.readBytesMap() - } - - // assumes that the frame body has been read into rbuf - switch f.header.op { - case opError: - frame = f.parseErrorFrame() - case opReady: - frame = f.parseReadyFrame() - case opResult: - frame, err = f.parseResultFrame() - case opSupported: - frame = f.parseSupportedFrame() - case opAuthenticate: - frame = f.parseAuthenticateFrame() - case opAuthChallenge: - frame = f.parseAuthChallengeFrame() - case opAuthSuccess: - frame = f.parseAuthSuccessFrame() - case opEvent: - frame = f.parseEventFrame() - default: - return nil, NewErrProtocol("unknown op in frame header: %s", f.header.op) - } - - return -} - -func (f *framer) parseErrorFrame() frame { - code := f.readInt() - msg := f.readString() - - errD := errorFrame{ - frameHeader: *f.header, - code: code, - message: msg, - } - - switch code { - case ErrCodeUnavailable: - cl := f.readConsistency() - required := f.readInt() - alive := f.readInt() - return &RequestErrUnavailable{ - errorFrame: errD, - Consistency: cl, - Required: required, - Alive: alive, - } - case ErrCodeWriteTimeout: - cl := f.readConsistency() - received := f.readInt() - blockfor := f.readInt() - writeType := f.readString() - return &RequestErrWriteTimeout{ - errorFrame: errD, - Consistency: cl, - Received: received, - BlockFor: blockfor, - WriteType: writeType, - } - case ErrCodeReadTimeout: - cl := f.readConsistency() - received := f.readInt() - blockfor := f.readInt() - dataPresent := f.readByte() - return &RequestErrReadTimeout{ - errorFrame: errD, - Consistency: cl, - Received: received, - BlockFor: blockfor, - DataPresent: dataPresent, - } - case ErrCodeAlreadyExists: - ks := f.readString() - table := f.readString() - return &RequestErrAlreadyExists{ - errorFrame: errD, - Keyspace: ks, - Table: table, - } - case ErrCodeUnprepared: - stmtId := f.readShortBytes() - return &RequestErrUnprepared{ - errorFrame: errD, - StatementId: copyBytes(stmtId), // defensively copy - } - case ErrCodeReadFailure: - res := &RequestErrReadFailure{ - errorFrame: errD, - } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() - if f.proto > protoVersion4 { - res.ErrorMap = f.readErrorMap() - res.NumFailures = len(res.ErrorMap) - } else { - res.NumFailures = f.readInt() - } - res.DataPresent = f.readByte() != 0 - - return res - case ErrCodeWriteFailure: - res := &RequestErrWriteFailure{ - errorFrame: errD, - } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() - if f.proto > protoVersion4 { - res.ErrorMap = f.readErrorMap() - res.NumFailures = len(res.ErrorMap) - } else { - res.NumFailures = f.readInt() - } - res.WriteType = f.readString() - return res - case ErrCodeFunctionFailure: - res := &RequestErrFunctionFailure{ - errorFrame: errD, - } - res.Keyspace = f.readString() - res.Function = f.readString() - res.ArgTypes = f.readStringList() - return res - - case ErrCodeCDCWriteFailure: - res := &RequestErrCDCWriteFailure{ - errorFrame: errD, - } - return res - case ErrCodeCASWriteUnknown: - res := &RequestErrCASWriteUnknown{ - errorFrame: errD, - } - res.Consistency = f.readConsistency() - res.Received = f.readInt() - res.BlockFor = f.readInt() - return res - case ErrCodeInvalid, ErrCodeBootstrapping, ErrCodeConfig, ErrCodeCredentials, ErrCodeOverloaded, - ErrCodeProtocol, ErrCodeServer, ErrCodeSyntax, ErrCodeTruncate, ErrCodeUnauthorized: - // TODO(zariel): we should have some distinct types for these errors - return errD - default: - panic(fmt.Errorf("unknown error code: 0x%x", errD.code)) - } -} - -func (f *framer) readErrorMap() (errMap ErrorMap) { - errMap = make(ErrorMap) - numErrs := f.readInt() - for i := 0; i < numErrs; i++ { - ip := f.readInetAdressOnly().String() - errMap[ip] = f.readShort() - } - return -} - -func (f *framer) writeHeader(flags byte, op frameOp, stream int) { - f.buf = f.buf[:0] - f.buf = append(f.buf, - f.proto, - flags, - ) - - if f.proto > protoVersion2 { - f.buf = append(f.buf, - byte(stream>>8), - byte(stream), - ) - } else { - f.buf = append(f.buf, - byte(stream), - ) - } - - // pad out length - f.buf = append(f.buf, - byte(op), - 0, - 0, - 0, - 0, - ) -} - -func (f *framer) setLength(length int) { - p := 4 - if f.proto > protoVersion2 { - p = 5 - } - - f.buf[p+0] = byte(length >> 24) - f.buf[p+1] = byte(length >> 16) - f.buf[p+2] = byte(length >> 8) - f.buf[p+3] = byte(length) -} - -func (f *framer) finish() error { - if len(f.buf) > maxFrameSize { - // huge app frame, lets remove it so it doesn't bloat the heap - f.buf = make([]byte, defaultBufSize) - return ErrFrameTooBig - } - - if f.buf[1]&flagCompress == flagCompress { - if f.compres == nil { - panic("compress flag set with no compressor") - } - - // TODO: only compress frames which are big enough - compressed, err := f.compres.Encode(f.buf[f.headSize:]) - if err != nil { - return err - } - - f.buf = append(f.buf[:f.headSize], compressed...) - } - length := len(f.buf) - f.headSize - f.setLength(length) - - return nil -} - -func (f *framer) writeTo(w io.Writer) error { - _, err := w.Write(f.buf) - return err -} - -func (f *framer) readTrace() { - f.traceID = f.readUUID().Bytes() -} - -type readyFrame struct { - frameHeader -} - -func (f *framer) parseReadyFrame() frame { - return &readyFrame{ - frameHeader: *f.header, - } -} - -type supportedFrame struct { - frameHeader - - supported map[string][]string -} - -// TODO: if we move the body buffer onto the frameHeader then we only need a single -// framer, and can move the methods onto the header. -func (f *framer) parseSupportedFrame() frame { - return &supportedFrame{ - frameHeader: *f.header, - - supported: f.readStringMultiMap(), - } -} - -type writeStartupFrame struct { - opts map[string]string -} - -func (w writeStartupFrame) String() string { - return fmt.Sprintf("[startup opts=%+v]", w.opts) -} - -func (w *writeStartupFrame) buildFrame(f *framer, streamID int) error { - f.writeHeader(f.flags&^flagCompress, opStartup, streamID) - f.writeStringMap(w.opts) - - return f.finish() -} - -type writePrepareFrame struct { - statement string - keyspace string - customPayload map[string][]byte -} - -func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error { - if len(w.customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opPrepare, streamID) - f.writeCustomPayload(&w.customPayload) - f.writeLongString(w.statement) - - var flags uint32 = 0 - if w.keyspace != "" { - if f.proto > protoVersion4 { - flags |= flagWithPreparedKeyspace - } else { - panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) - } - } - if f.proto > protoVersion4 { - f.writeUint(flags) - } - if w.keyspace != "" { - f.writeString(w.keyspace) - } - - return f.finish() -} - -func (f *framer) readTypeInfo() TypeInfo { - // TODO: factor this out so the same code paths can be used to parse custom - // types and other types, as much of the logic will be duplicated. - id := f.readShort() - - simple := NativeType{ - proto: f.proto, - typ: Type(id), - } - - if simple.typ == TypeCustom { - simple.custom = f.readString() - if cassType := getApacheCassandraType(simple.custom); cassType != TypeCustom { - simple.typ = cassType - } - } - - switch simple.typ { - case TypeTuple: - n := f.readShort() - tuple := TupleTypeInfo{ - NativeType: simple, - Elems: make([]TypeInfo, n), - } - - for i := 0; i < int(n); i++ { - tuple.Elems[i] = f.readTypeInfo() - } - - return tuple - - case TypeUDT: - udt := UDTTypeInfo{ - NativeType: simple, - } - udt.KeySpace = f.readString() - udt.Name = f.readString() - - n := f.readShort() - udt.Elements = make([]UDTField, n) - for i := 0; i < int(n); i++ { - field := &udt.Elements[i] - field.Name = f.readString() - field.Type = f.readTypeInfo() - } - - return udt - case TypeMap, TypeList, TypeSet: - collection := CollectionType{ - NativeType: simple, - } - - if simple.typ == TypeMap { - collection.Key = f.readTypeInfo() - } - - collection.Elem = f.readTypeInfo() - - return collection - } - - return simple -} - -type preparedMetadata struct { - resultMetadata - - // proto v4+ - pkeyColumns []int - - keyspace string - - table string -} - -func (r preparedMetadata) String() string { - return fmt.Sprintf("[prepared flags=0x%x pkey=%v paging_state=% X columns=%v col_count=%d actual_col_count=%d]", r.flags, r.pkeyColumns, r.pagingState, r.columns, r.colCount, r.actualColCount) -} - -func (f *framer) parsePreparedMetadata() preparedMetadata { - // TODO: deduplicate this from parseMetadata - meta := preparedMetadata{} - - meta.flags = f.readInt() - meta.colCount = f.readInt() - if meta.colCount < 0 { - panic(fmt.Errorf("received negative column count: %d", meta.colCount)) - } - meta.actualColCount = meta.colCount - - if f.proto >= protoVersion4 { - pkeyCount := f.readInt() - pkeys := make([]int, pkeyCount) - for i := 0; i < pkeyCount; i++ { - pkeys[i] = int(f.readShort()) - } - meta.pkeyColumns = pkeys - } - - if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) - } - - if meta.flags&flagNoMetaData == flagNoMetaData { - return meta - } - - globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec - if globalSpec { - meta.keyspace = f.readString() - meta.table = f.readString() - } - - var cols []ColumnInfo - if meta.colCount < 1000 { - // preallocate columninfo to avoid excess copying - cols = make([]ColumnInfo, meta.colCount) - for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) - } - } else { - // use append, huge number of columns usually indicates a corrupt frame or - // just a huge row. - for i := 0; i < meta.colCount; i++ { - var col ColumnInfo - f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) - cols = append(cols, col) - } - } - - meta.columns = cols - - return meta -} - -type resultMetadata struct { - flags int - - // only if flagPageState - pagingState []byte - - columns []ColumnInfo - colCount int - - // this is a count of the total number of columns which can be scanned, - // it is at minimum len(columns) but may be larger, for instance when a column - // is a UDT or tuple. - actualColCount int -} - -func (r *resultMetadata) morePages() bool { - return r.flags&flagHasMorePages == flagHasMorePages -} - -func (r resultMetadata) String() string { - return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns) -} - -func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) { - if !globalSpec { - col.Keyspace = f.readString() - col.Table = f.readString() - } else { - col.Keyspace = keyspace - col.Table = table - } - - col.Name = f.readString() - col.TypeInfo = f.readTypeInfo() - switch v := col.TypeInfo.(type) { - // maybe also UDT - case TupleTypeInfo: - // -1 because we already included the tuple column - meta.actualColCount += len(v.Elems) - 1 - } -} - -func (f *framer) parseResultMetadata() resultMetadata { - var meta resultMetadata - - meta.flags = f.readInt() - meta.colCount = f.readInt() - if meta.colCount < 0 { - panic(fmt.Errorf("received negative column count: %d", meta.colCount)) - } - meta.actualColCount = meta.colCount - - if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) - } - - if meta.flags&flagNoMetaData == flagNoMetaData { - return meta - } - - var keyspace, table string - globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec - if globalSpec { - keyspace = f.readString() - table = f.readString() - } - - var cols []ColumnInfo - if meta.colCount < 1000 { - // preallocate columninfo to avoid excess copying - cols = make([]ColumnInfo, meta.colCount) - for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta, globalSpec, keyspace, table) - } - - } else { - // use append, huge number of columns usually indicates a corrupt frame or - // just a huge row. - for i := 0; i < meta.colCount; i++ { - var col ColumnInfo - f.readCol(&col, &meta, globalSpec, keyspace, table) - cols = append(cols, col) - } - } - - meta.columns = cols - - return meta -} - -type resultVoidFrame struct { - frameHeader -} - -func (f *resultVoidFrame) String() string { - return "[result_void]" -} - -func (f *framer) parseResultFrame() (frame, error) { - kind := f.readInt() - - switch kind { - case resultKindVoid: - return &resultVoidFrame{frameHeader: *f.header}, nil - case resultKindRows: - return f.parseResultRows(), nil - case resultKindKeyspace: - return f.parseResultSetKeyspace(), nil - case resultKindPrepared: - return f.parseResultPrepared(), nil - case resultKindSchemaChanged: - return f.parseResultSchemaChange(), nil - } - - return nil, NewErrProtocol("unknown result kind: %x", kind) -} - -type resultRowsFrame struct { - frameHeader - - meta resultMetadata - // dont parse the rows here as we only need to do it once - numRows int -} - -func (f *resultRowsFrame) String() string { - return fmt.Sprintf("[result_rows meta=%v]", f.meta) -} - -func (f *framer) parseResultRows() frame { - result := &resultRowsFrame{} - result.meta = f.parseResultMetadata() - - result.numRows = f.readInt() - if result.numRows < 0 { - panic(fmt.Errorf("invalid row_count in result frame: %d", result.numRows)) - } - - return result -} - -type resultKeyspaceFrame struct { - frameHeader - keyspace string -} - -func (r *resultKeyspaceFrame) String() string { - return fmt.Sprintf("[result_keyspace keyspace=%s]", r.keyspace) -} - -func (f *framer) parseResultSetKeyspace() frame { - return &resultKeyspaceFrame{ - frameHeader: *f.header, - keyspace: f.readString(), - } -} - -type resultPreparedFrame struct { - frameHeader - - preparedID []byte - reqMeta preparedMetadata - respMeta resultMetadata -} - -func (f *framer) parseResultPrepared() frame { - frame := &resultPreparedFrame{ - frameHeader: *f.header, - preparedID: f.readShortBytes(), - reqMeta: f.parsePreparedMetadata(), - } - - if f.proto < protoVersion2 { - return frame - } - - frame.respMeta = f.parseResultMetadata() - - return frame -} - -type schemaChangeKeyspace struct { - frameHeader - - change string - keyspace string -} - -func (f schemaChangeKeyspace) String() string { - return fmt.Sprintf("[event schema_change_keyspace change=%q keyspace=%q]", f.change, f.keyspace) -} - -type schemaChangeTable struct { - frameHeader - - change string - keyspace string - object string -} - -func (f schemaChangeTable) String() string { - return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.keyspace, f.object) -} - -type schemaChangeType struct { - frameHeader - - change string - keyspace string - object string -} - -type schemaChangeFunction struct { - frameHeader - - change string - keyspace string - name string - args []string -} - -type schemaChangeAggregate struct { - frameHeader - - change string - keyspace string - name string - args []string -} - -func (f *framer) parseResultSchemaChange() frame { - if f.proto <= protoVersion2 { - change := f.readString() - keyspace := f.readString() - table := f.readString() - - if table != "" { - return &schemaChangeTable{ - frameHeader: *f.header, - change: change, - keyspace: keyspace, - object: table, - } - } else { - return &schemaChangeKeyspace{ - frameHeader: *f.header, - change: change, - keyspace: keyspace, - } - } - } else { - change := f.readString() - target := f.readString() - - // TODO: could just use a separate type for each target - switch target { - case "KEYSPACE": - frame := &schemaChangeKeyspace{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - - return frame - case "TABLE": - frame := &schemaChangeTable{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - frame.object = f.readString() - - return frame - case "TYPE": - frame := &schemaChangeType{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - frame.object = f.readString() - - return frame - case "FUNCTION": - frame := &schemaChangeFunction{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - frame.name = f.readString() - frame.args = f.readStringList() - - return frame - case "AGGREGATE": - frame := &schemaChangeAggregate{ - frameHeader: *f.header, - change: change, - } - - frame.keyspace = f.readString() - frame.name = f.readString() - frame.args = f.readStringList() - - return frame - default: - panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change)) - } - } - -} - -type authenticateFrame struct { - frameHeader - - class string -} - -func (a *authenticateFrame) String() string { - return fmt.Sprintf("[authenticate class=%q]", a.class) -} - -func (f *framer) parseAuthenticateFrame() frame { - return &authenticateFrame{ - frameHeader: *f.header, - class: f.readString(), - } -} - -type authSuccessFrame struct { - frameHeader - - data []byte -} - -func (a *authSuccessFrame) String() string { - return fmt.Sprintf("[auth_success data=%q]", a.data) -} - -func (f *framer) parseAuthSuccessFrame() frame { - return &authSuccessFrame{ - frameHeader: *f.header, - data: f.readBytes(), - } -} - -type authChallengeFrame struct { - frameHeader - - data []byte -} - -func (a *authChallengeFrame) String() string { - return fmt.Sprintf("[auth_challenge data=%q]", a.data) -} - -func (f *framer) parseAuthChallengeFrame() frame { - return &authChallengeFrame{ - frameHeader: *f.header, - data: f.readBytes(), - } -} - -type statusChangeEventFrame struct { - frameHeader - - change string - host net.IP - port int -} - -func (t statusChangeEventFrame) String() string { - return fmt.Sprintf("[status_change change=%s host=%v port=%v]", t.change, t.host, t.port) -} - -// essentially the same as statusChange -type topologyChangeEventFrame struct { - frameHeader - - change string - host net.IP - port int -} - -func (t topologyChangeEventFrame) String() string { - return fmt.Sprintf("[topology_change change=%s host=%v port=%v]", t.change, t.host, t.port) -} - -func (f *framer) parseEventFrame() frame { - eventType := f.readString() - - switch eventType { - case "TOPOLOGY_CHANGE": - frame := &topologyChangeEventFrame{frameHeader: *f.header} - frame.change = f.readString() - frame.host, frame.port = f.readInet() - - return frame - case "STATUS_CHANGE": - frame := &statusChangeEventFrame{frameHeader: *f.header} - frame.change = f.readString() - frame.host, frame.port = f.readInet() - - return frame - case "SCHEMA_CHANGE": - // this should work for all versions - return f.parseResultSchemaChange() - default: - panic(fmt.Errorf("gocql: unknown event type: %q", eventType)) - } - -} - -type writeAuthResponseFrame struct { - data []byte -} - -func (a *writeAuthResponseFrame) String() string { - return fmt.Sprintf("[auth_response data=%q]", a.data) -} - -func (a *writeAuthResponseFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeAuthResponseFrame(streamID, a.data) -} - -func (f *framer) writeAuthResponseFrame(streamID int, data []byte) error { - f.writeHeader(f.flags, opAuthResponse, streamID) - f.writeBytes(data) - return f.finish() -} - -type queryValues struct { - value []byte - - // optional name, will set With names for values flag - name string - isUnset bool -} - -type queryParams struct { - consistency Consistency - // v2+ - skipMeta bool - values []queryValues - pageSize int - pagingState []byte - serialConsistency SerialConsistency - // v3+ - defaultTimestamp bool - defaultTimestampValue int64 - // v5+ - keyspace string -} - -func (q queryParams) String() string { - return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]", - q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace) -} - -func (f *framer) writeQueryParams(opts *queryParams) { - f.writeConsistency(opts.consistency) - - if f.proto == protoVersion1 { - return - } - - var flags byte - if len(opts.values) > 0 { - flags |= flagValues - } - if opts.skipMeta { - flags |= flagSkipMetaData - } - if opts.pageSize > 0 { - flags |= flagPageSize - } - if len(opts.pagingState) > 0 { - flags |= flagWithPagingState - } - if opts.serialConsistency > 0 { - flags |= flagWithSerialConsistency - } - - names := false - - // protoV3 specific things - if f.proto > protoVersion2 { - if opts.defaultTimestamp { - flags |= flagDefaultTimestamp - } - - if len(opts.values) > 0 && opts.values[0].name != "" { - flags |= flagWithNameValues - names = true - } - } - - if opts.keyspace != "" { - if f.proto > protoVersion4 { - flags |= flagWithKeyspace - } else { - panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) - } - } - - if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) - } else { - f.writeByte(flags) - } - - if n := len(opts.values); n > 0 { - f.writeShort(uint16(n)) - - for i := 0; i < n; i++ { - if names { - f.writeString(opts.values[i].name) - } - if opts.values[i].isUnset { - f.writeUnset() - } else { - f.writeBytes(opts.values[i].value) - } - } - } - - if opts.pageSize > 0 { - f.writeInt(int32(opts.pageSize)) - } - - if len(opts.pagingState) > 0 { - f.writeBytes(opts.pagingState) - } - - if opts.serialConsistency > 0 { - f.writeConsistency(Consistency(opts.serialConsistency)) - } - - if f.proto > protoVersion2 && opts.defaultTimestamp { - // timestamp in microseconds - var ts int64 - if opts.defaultTimestampValue != 0 { - ts = opts.defaultTimestampValue - } else { - ts = time.Now().UnixNano() / 1000 - } - f.writeLong(ts) - } - - if opts.keyspace != "" { - f.writeString(opts.keyspace) - } -} - -type writeQueryFrame struct { - statement string - params queryParams - - // v4+ - customPayload map[string][]byte -} - -func (w *writeQueryFrame) String() string { - return fmt.Sprintf("[query statement=%q params=%v]", w.statement, w.params) -} - -func (w *writeQueryFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeQueryFrame(streamID, w.statement, &w.params, w.customPayload) -} - -func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams, customPayload map[string][]byte) error { - if len(customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opQuery, streamID) - f.writeCustomPayload(&customPayload) - f.writeLongString(statement) - f.writeQueryParams(params) - - return f.finish() -} - -type frameBuilder interface { - buildFrame(framer *framer, streamID int) error -} - -type frameWriterFunc func(framer *framer, streamID int) error - -func (f frameWriterFunc) buildFrame(framer *framer, streamID int) error { - return f(framer, streamID) -} - -type writeExecuteFrame struct { - preparedID []byte - params queryParams - - // v4+ - customPayload map[string][]byte -} - -func (e *writeExecuteFrame) String() string { - return fmt.Sprintf("[execute id=% X params=%v]", e.preparedID, &e.params) -} - -func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error { - return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) -} - -func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { - if len(*customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opExecute, streamID) - f.writeCustomPayload(customPayload) - f.writeShortBytes(preparedID) - if f.proto > protoVersion1 { - f.writeQueryParams(params) - } else { - n := len(params.values) - f.writeShort(uint16(n)) - for i := 0; i < n; i++ { - if params.values[i].isUnset { - f.writeUnset() - } else { - f.writeBytes(params.values[i].value) - } - } - f.writeConsistency(params.consistency) - } - - return f.finish() -} - -// TODO: can we replace BatchStatemt with batchStatement? As they prety much -// duplicate each other -type batchStatment struct { - preparedID []byte - statement string - values []queryValues -} - -type writeBatchFrame struct { - typ BatchType - statements []batchStatment - consistency Consistency - - // v3+ - serialConsistency SerialConsistency - defaultTimestamp bool - defaultTimestampValue int64 - - //v4+ - customPayload map[string][]byte -} - -func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeBatchFrame(streamID, w, w.customPayload) -} - -func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) error { - if len(customPayload) > 0 { - f.payload() - } - f.writeHeader(f.flags, opBatch, streamID) - f.writeCustomPayload(&customPayload) - f.writeByte(byte(w.typ)) - - n := len(w.statements) - f.writeShort(uint16(n)) - - var flags byte - - for i := 0; i < n; i++ { - b := &w.statements[i] - if len(b.preparedID) == 0 { - f.writeByte(0) - f.writeLongString(b.statement) - } else { - f.writeByte(1) - f.writeShortBytes(b.preparedID) - } - - f.writeShort(uint16(len(b.values))) - for j := range b.values { - col := b.values[j] - if f.proto > protoVersion2 && col.name != "" { - // TODO: move this check into the caller and set a flag on writeBatchFrame - // to indicate using named values - if f.proto <= protoVersion5 { - return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") - } - flags |= flagWithNameValues - f.writeString(col.name) - } - if col.isUnset { - f.writeUnset() - } else { - f.writeBytes(col.value) - } - } - } - - f.writeConsistency(w.consistency) - - if f.proto > protoVersion2 { - if w.serialConsistency > 0 { - flags |= flagWithSerialConsistency - } - if w.defaultTimestamp { - flags |= flagDefaultTimestamp - } - - if f.proto > protoVersion4 { - f.writeUint(uint32(flags)) - } else { - f.writeByte(flags) - } - - if w.serialConsistency > 0 { - f.writeConsistency(Consistency(w.serialConsistency)) - } - - if w.defaultTimestamp { - var ts int64 - if w.defaultTimestampValue != 0 { - ts = w.defaultTimestampValue - } else { - ts = time.Now().UnixNano() / 1000 - } - f.writeLong(ts) - } - } - - return f.finish() -} - -type writeOptionsFrame struct{} - -func (w *writeOptionsFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeOptionsFrame(streamID, w) -} - -func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error { - f.writeHeader(f.flags&^flagCompress, opOptions, stream) - return f.finish() -} - -type writeRegisterFrame struct { - events []string -} - -func (w *writeRegisterFrame) buildFrame(framer *framer, streamID int) error { - return framer.writeRegisterFrame(streamID, w) -} - -func (f *framer) writeRegisterFrame(streamID int, w *writeRegisterFrame) error { - f.writeHeader(f.flags, opRegister, streamID) - f.writeStringList(w.events) - - return f.finish() -} - -func (f *framer) readByte() byte { - if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.buf))) - } - - b := f.buf[0] - f.buf = f.buf[1:] - return b -} - -func (f *framer) readInt() (n int) { - if len(f.buf) < 4 { - panic(fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.buf))) - } - - n = int(int32(f.buf[0])<<24 | int32(f.buf[1])<<16 | int32(f.buf[2])<<8 | int32(f.buf[3])) - f.buf = f.buf[4:] - return -} - -func (f *framer) readShort() (n uint16) { - if len(f.buf) < 2 { - panic(fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.buf))) - } - n = uint16(f.buf[0])<<8 | uint16(f.buf[1]) - f.buf = f.buf[2:] - return -} - -func (f *framer) readString() (s string) { - size := f.readShort() - - if len(f.buf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.buf))) - } - - s = string(f.buf[:size]) - f.buf = f.buf[size:] - return -} - -func (f *framer) readLongString() (s string) { - size := f.readInt() - - if len(f.buf) < size { - panic(fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.buf))) - } - - s = string(f.buf[:size]) - f.buf = f.buf[size:] - return -} - -func (f *framer) readUUID() *UUID { - if len(f.buf) < 16 { - panic(fmt.Errorf("not enough bytes in buffer to read uuid require %d got: %d", 16, len(f.buf))) - } - - // TODO: how to handle this error, if it is a uuid, then sureley, problems? - u, _ := UUIDFromBytes(f.buf[:16]) - f.buf = f.buf[16:] - return &u -} - -func (f *framer) readStringList() []string { - size := f.readShort() - - l := make([]string, size) - for i := 0; i < int(size); i++ { - l[i] = f.readString() - } - - return l -} - -func (f *framer) readBytesInternal() ([]byte, error) { - size := f.readInt() - if size < 0 { - return nil, nil - } - - if len(f.buf) < size { - return nil, fmt.Errorf("not enough bytes in buffer to read bytes require %d got: %d", size, len(f.buf)) - } - - l := f.buf[:size] - f.buf = f.buf[size:] - - return l, nil -} - -func (f *framer) readBytes() []byte { - l, err := f.readBytesInternal() - if err != nil { - panic(err) - } - - return l -} - -func (f *framer) readShortBytes() []byte { - size := f.readShort() - if len(f.buf) < int(size) { - panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.buf))) - } - - l := f.buf[:size] - f.buf = f.buf[size:] - - return l -} - -func (f *framer) readInetAdressOnly() net.IP { - if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.buf))) - } - - size := f.buf[0] - f.buf = f.buf[1:] - - if !(size == 4 || size == 16) { - panic(fmt.Errorf("invalid IP size: %d", size)) - } - - if len(f.buf) < 1 { - panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.buf))) - } - - ip := make([]byte, size) - copy(ip, f.buf[:size]) - f.buf = f.buf[size:] - return net.IP(ip) -} - -func (f *framer) readInet() (net.IP, int) { - return f.readInetAdressOnly(), f.readInt() -} - -func (f *framer) readConsistency() Consistency { - return Consistency(f.readShort()) -} - -func (f *framer) readBytesMap() map[string][]byte { - size := f.readShort() - m := make(map[string][]byte, size) - - for i := 0; i < int(size); i++ { - k := f.readString() - v := f.readBytes() - m[k] = v - } - - return m -} - -func (f *framer) readStringMultiMap() map[string][]string { - size := f.readShort() - m := make(map[string][]string, size) - - for i := 0; i < int(size); i++ { - k := f.readString() - v := f.readStringList() - m[k] = v - } - - return m -} - -func (f *framer) writeByte(b byte) { - f.buf = append(f.buf, b) -} - -func appendBytes(p []byte, d []byte) []byte { - if d == nil { - return appendInt(p, -1) - } - p = appendInt(p, int32(len(d))) - p = append(p, d...) - return p -} - -func appendShort(p []byte, n uint16) []byte { - return append(p, - byte(n>>8), - byte(n), - ) -} - -func appendInt(p []byte, n int32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendUint(p []byte, n uint32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendLong(p []byte, n int64) []byte { - return append(p, - byte(n>>56), - byte(n>>48), - byte(n>>40), - byte(n>>32), - byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n), - ) -} - -func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { - if len(*customPayload) > 0 { - if f.proto < protoVersion4 { - panic("Custom payload is not supported with version V3 or less") - } - f.writeBytesMap(*customPayload) - } -} - -// these are protocol level binary types -func (f *framer) writeInt(n int32) { - f.buf = appendInt(f.buf, n) -} - -func (f *framer) writeUint(n uint32) { - f.buf = appendUint(f.buf, n) -} - -func (f *framer) writeShort(n uint16) { - f.buf = appendShort(f.buf, n) -} - -func (f *framer) writeLong(n int64) { - f.buf = appendLong(f.buf, n) -} - -func (f *framer) writeString(s string) { - f.writeShort(uint16(len(s))) - f.buf = append(f.buf, s...) -} - -func (f *framer) writeLongString(s string) { - f.writeInt(int32(len(s))) - f.buf = append(f.buf, s...) -} - -func (f *framer) writeStringList(l []string) { - f.writeShort(uint16(len(l))) - for _, s := range l { - f.writeString(s) - } -} - -func (f *framer) writeUnset() { - // Protocol version 4 specifies that bind variables do not require having a - // value when executing a statement. Bind variables without a value are - // called 'unset'. The 'unset' bind variable is serialized as the int - // value '-2' without following bytes. - f.writeInt(-2) -} - -func (f *framer) writeBytes(p []byte) { - // TODO: handle null case correctly, - // [bytes] A [int] n, followed by n bytes if n >= 0. If n < 0, - // no byte should follow and the value represented is `null`. - if p == nil { - f.writeInt(-1) - } else { - f.writeInt(int32(len(p))) - f.buf = append(f.buf, p...) - } -} - -func (f *framer) writeShortBytes(p []byte) { - f.writeShort(uint16(len(p))) - f.buf = append(f.buf, p...) -} - -func (f *framer) writeConsistency(cons Consistency) { - f.writeShort(uint16(cons)) -} - -func (f *framer) writeStringMap(m map[string]string) { - f.writeShort(uint16(len(m))) - for k, v := range m { - f.writeString(k) - f.writeString(v) - } -} - -func (f *framer) writeBytesMap(m map[string][]byte) { - f.writeShort(uint16(len(m))) - for k, v := range m { - f.writeString(k) - f.writeBytes(v) - } -} diff --git a/frame_test.go b/frame_test.go index 170cba710..50b9b0277 100644 --- a/frame_test.go +++ b/frame_test.go @@ -60,18 +60,18 @@ func TestFuzzBugs(t *testing.T) { t.Logf("test %d input: %q", i, test) r := bytes.NewReader(test) - head, err := readHeader(r, make([]byte, 9)) + head, err := protocol.ReadHeader(r, make([]byte, 9)) if err != nil { continue } - framer := newFramer(nil, byte(head.version)) + framer := protocol.NewFramer(nil, byte(head.version)) err = framer.readFrame(r, &head) if err != nil { continue } - frame, err := framer.parseFrame() + frame, err := framer.ParseFrame() if err != nil { continue } @@ -86,7 +86,7 @@ func TestFrameWriteTooLong(t *testing.T) { t.Skip("skipping test in travis due to memory pressure with the race detecor") } - framer := newFramer(nil, 2) + framer := protocol.NewFramer(nil, 2) framer.writeHeader(0, opStartup, 1) framer.writeBytes(make([]byte, maxFrameSize+1)) @@ -106,7 +106,7 @@ func TestFrameReadTooLong(t *testing.T) { // write a new header right after this frame to verify that we can read it r.Write([]byte{0x02, 0x00, 0x00, byte(opReady), 0x00, 0x00, 0x00, 0x00}) - framer := newFramer(nil, 2) + framer := protocol.NewFramer(nil, 2) head := frameHeader{ version: 2, @@ -119,7 +119,7 @@ func TestFrameReadTooLong(t *testing.T) { t.Fatalf("expected to get %v got %v", ErrFrameTooBig, err) } - head, err = readHeader(r, make([]byte, 8)) + head, err = protocol.ReadHeader(r, make([]byte, 8)) if err != nil { t.Fatal(err) } diff --git a/framer_bench_test.go b/framer_bench_test.go index bce3742c2..83716c56e 100644 --- a/framer_bench_test.go +++ b/framer_bench_test.go @@ -64,7 +64,7 @@ func BenchmarkParseRowsFrame(b *testing.B) { buf: data, } - _, err = framer.parseFrame() + _, err = framer.ParseFrame() if err != nil { b.Fatal(err) } diff --git a/fuzz.go b/fuzz.go index 58dd3b691..284e0a292 100644 --- a/fuzz.go +++ b/fuzz.go @@ -34,18 +34,18 @@ func Fuzz(data []byte) int { r := bytes.NewReader(data) - head, err := readHeader(r, make([]byte, 9)) + head, err := protocol.ReadHeader(r, make([]byte, 9)) if err != nil { return 0 } - framer := newFramer(r, &bw, nil, byte(head.version)) - err = framer.readFrame(&head) + framer := protocol.NewFramer(r, &bw, nil, byte(head.version)) + err = framer.ReadFrame(&head) if err != nil { return 0 } - frame, err := framer.parseFrame() + frame, err := framer.ParseFrame() if err != nil { return 0 } diff --git a/helpers.go b/helpers.go index f2faee9e0..62ee40b93 100644 --- a/helpers.go +++ b/helpers.go @@ -26,13 +26,8 @@ package gocql import ( "fmt" - "math/big" "net" "reflect" - "strings" - "time" - - "gopkg.in/inf.v0" ) type RowData struct { @@ -40,268 +35,10 @@ type RowData struct { Values []interface{} } -func goType(t TypeInfo) (reflect.Type, error) { - switch t.Type() { - case TypeVarchar, TypeAscii, TypeInet, TypeText: - return reflect.TypeOf(*new(string)), nil - case TypeBigInt, TypeCounter: - return reflect.TypeOf(*new(int64)), nil - case TypeTime: - return reflect.TypeOf(*new(time.Duration)), nil - case TypeTimestamp: - return reflect.TypeOf(*new(time.Time)), nil - case TypeBlob: - return reflect.TypeOf(*new([]byte)), nil - case TypeBoolean: - return reflect.TypeOf(*new(bool)), nil - case TypeFloat: - return reflect.TypeOf(*new(float32)), nil - case TypeDouble: - return reflect.TypeOf(*new(float64)), nil - case TypeInt: - return reflect.TypeOf(*new(int)), nil - case TypeSmallInt: - return reflect.TypeOf(*new(int16)), nil - case TypeTinyInt: - return reflect.TypeOf(*new(int8)), nil - case TypeDecimal: - return reflect.TypeOf(*new(*inf.Dec)), nil - case TypeUUID, TypeTimeUUID: - return reflect.TypeOf(*new(UUID)), nil - case TypeList, TypeSet: - elemType, err := goType(t.(CollectionType).Elem) - if err != nil { - return nil, err - } - return reflect.SliceOf(elemType), nil - case TypeMap: - keyType, err := goType(t.(CollectionType).Key) - if err != nil { - return nil, err - } - valueType, err := goType(t.(CollectionType).Elem) - if err != nil { - return nil, err - } - return reflect.MapOf(keyType, valueType), nil - case TypeVarint: - return reflect.TypeOf(*new(*big.Int)), nil - case TypeTuple: - // what can we do here? all there is to do is to make a list of interface{} - tuple := t.(TupleTypeInfo) - return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil - case TypeUDT: - return reflect.TypeOf(make(map[string]interface{})), nil - case TypeDate: - return reflect.TypeOf(*new(time.Time)), nil - case TypeDuration: - return reflect.TypeOf(*new(Duration)), nil - default: - return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) - } -} - func dereference(i interface{}) interface{} { return reflect.Indirect(reflect.ValueOf(i)).Interface() } -func getCassandraBaseType(name string) Type { - switch name { - case "ascii": - return TypeAscii - case "bigint": - return TypeBigInt - case "blob": - return TypeBlob - case "boolean": - return TypeBoolean - case "counter": - return TypeCounter - case "date": - return TypeDate - case "decimal": - return TypeDecimal - case "double": - return TypeDouble - case "duration": - return TypeDuration - case "float": - return TypeFloat - case "int": - return TypeInt - case "smallint": - return TypeSmallInt - case "tinyint": - return TypeTinyInt - case "time": - return TypeTime - case "timestamp": - return TypeTimestamp - case "uuid": - return TypeUUID - case "varchar": - return TypeVarchar - case "text": - return TypeText - case "varint": - return TypeVarint - case "timeuuid": - return TypeTimeUUID - case "inet": - return TypeInet - case "MapType": - return TypeMap - case "ListType": - return TypeList - case "SetType": - return TypeSet - case "TupleType": - return TypeTuple - default: - return TypeCustom - } -} - -func getCassandraType(name string, logger StdLogger) TypeInfo { - if strings.HasPrefix(name, "frozen<") { - return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) - } else if strings.HasPrefix(name, "set<") { - return CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), - } - } else if strings.HasPrefix(name, "list<") { - return CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), - } - } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) - if len(names) != 2 { - logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NativeType{ - typ: TypeCustom, - } - } - return CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: getCassandraType(names[0], logger), - Elem: getCassandraType(names[1], logger), - } - } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) - types := make([]TypeInfo, len(names)) - - for i, name := range names { - types[i] = getCassandraType(name, logger) - } - - return TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, - Elems: types, - } - } else { - return NativeType{ - typ: getCassandraBaseType(name), - } - } -} - -func splitCompositeTypes(name string) []string { - if !strings.Contains(name, "<") { - return strings.Split(name, ", ") - } - var parts []string - lessCount := 0 - segment := "" - for _, char := range name { - if char == ',' && lessCount == 0 { - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - segment = "" - continue - } - segment += string(char) - if char == '<' { - lessCount++ - } else if char == '>' { - lessCount-- - } - } - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - return parts -} - -func apacheToCassandraType(t string) string { - t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) - t = strings.Replace(t, "(", "<", -1) - t = strings.Replace(t, ")", ">", -1) - types := strings.FieldsFunc(t, func(r rune) bool { - return r == '<' || r == '>' || r == ',' - }) - for _, typ := range types { - t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) - } - // This is done so it exactly matches what Cassandra returns - return strings.Replace(t, ",", ", ", -1) -} - -func getApacheCassandraType(class string) Type { - switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { - case "AsciiType": - return TypeAscii - case "LongType": - return TypeBigInt - case "BytesType": - return TypeBlob - case "BooleanType": - return TypeBoolean - case "CounterColumnType": - return TypeCounter - case "DecimalType": - return TypeDecimal - case "DoubleType": - return TypeDouble - case "FloatType": - return TypeFloat - case "Int32Type": - return TypeInt - case "ShortType": - return TypeSmallInt - case "ByteType": - return TypeTinyInt - case "TimeType": - return TypeTime - case "DateType", "TimestampType": - return TypeTimestamp - case "UUIDType", "LexicalUUIDType": - return TypeUUID - case "UTF8Type": - return TypeVarchar - case "IntegerType": - return TypeVarint - case "TimeUUIDType": - return TypeTimeUUID - case "InetAddressType": - return TypeInet - case "MapType": - return TypeMap - case "ListType": - return TypeList - case "SetType": - return TypeSet - case "TupleType": - return TypeTuple - case "DurationType": - return TypeDuration - default: - return TypeCustom - } -} - func (r *RowData) rowMap(m map[string]interface{}) { for i, column := range r.Columns { val := dereference(r.Values[i]) @@ -392,44 +129,6 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) { return dataToReturn, nil } -// MapScan takes a map[string]interface{} and populates it with a row -// that is returned from cassandra. -// -// Each call to MapScan() must be called with a new map object. -// During the call to MapScan() any pointers in the existing map -// are replaced with non pointer types before the call returns -// -// iter := session.Query(`SELECT * FROM mytable`).Iter() -// for { -// // New map each iteration -// row := make(map[string]interface{}) -// if !iter.MapScan(row) { -// break -// } -// // Do things with row -// if fullname, ok := row["fullname"]; ok { -// fmt.Printf("Full Name: %s\n", fullname) -// } -// } -// -// You can also pass pointers in the map before each call -// -// var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces -// var address net.IP -// var age int -// iter := session.Query(`SELECT * FROM scan_map_table`).Iter() -// for { -// // New map each iteration -// row := map[string]interface{}{ -// "fullname": &fullName, -// "age": &age, -// "address": &address, -// } -// if !iter.MapScan(row) { -// break -// } -// fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address) -// } func (iter *Iter) MapScan(m map[string]interface{}) bool { if iter.err != nil { return false @@ -451,12 +150,6 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool { return false } -func copyBytes(p []byte) []byte { - b := make([]byte, len(p)) - copy(b, p) - return b -} - var failDNS = false func LookupIP(host string) ([]net.IP, error) { diff --git a/helpers_test.go b/helpers_test.go index 67922ba5d..5a05273f2 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -30,7 +30,7 @@ import ( ) func TestGetCassandraType_Set(t *testing.T) { - typ := getCassandraType("set", &defaultLogger{}) + typ := protocol.GetCassandraType("set", &defaultLogger{}) set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) @@ -227,7 +227,7 @@ func TestGetCassandraType(t *testing.T) { for _, test := range tests { t.Run(test.input, func(t *testing.T) { - got := getCassandraType(test.input, &defaultLogger{}) + got := protocol.GetCassandraType(test.input, &defaultLogger{}) // TODO(zariel): define an equal method on the types? if !reflect.DeepEqual(got, test.exp) { diff --git a/host_source.go b/host_source.go index a0bab9ad0..340c344fc 100644 --- a/host_source.go +++ b/host_source.go @@ -457,7 +457,7 @@ type ringDescriber struct { func checkSystemSchema(control *controlConn) (bool, error) { iter := control.query("SELECT * FROM system_schema.keyspaces") if err := iter.err; err != nil { - if errf, ok := err.(*errorFrame); ok { + if errf, ok := err.(*internal_errors.ErrorFrame); ok { if errf.code == ErrCodeSyntax { return false, nil } diff --git a/internal/ccm/ccm.go b/internal/ccm/ccm.go index 55b540158..27ba9b26c 100644 --- a/internal/ccm/ccm.go +++ b/internal/ccm/ccm.go @@ -179,7 +179,7 @@ func Status() (map[string]Host, error) { line := strings.Split(strings.TrimSpace(text), "=") k, v := line[0], line[1] if k == "binary" { - // could check errors + // could check internal_errors // ('127.0.0.1', 9042) v = v[2:] // ('' if i := strings.IndexByte(v, '\''); i < 0 { diff --git a/internal/compressor/compressor.go b/internal/compressor/compressor.go new file mode 100644 index 000000000..f54906899 --- /dev/null +++ b/internal/compressor/compressor.go @@ -0,0 +1,26 @@ +package compressor + +import "github.com/golang/snappy" + +type Compressor interface { + Name() string + Encode(data []byte) ([]byte, error) + Decode(data []byte) ([]byte, error) +} + +// SnappyCompressor implements the Compressor interface and can be used to +// compress incoming and outgoing frames. The snappy compression algorithm +// aims for very high speeds and reasonable compression. +type SnappyCompressor struct{} + +func (s SnappyCompressor) Name() string { + return "snappy" +} + +func (s SnappyCompressor) Encode(data []byte) ([]byte, error) { + return snappy.Encode(nil, data), nil +} + +func (s SnappyCompressor) Decode(data []byte) ([]byte, error) { + return snappy.Decode(nil, data) +} diff --git a/internal/internal_errors/errors.go b/internal/internal_errors/errors.go new file mode 100644 index 000000000..7095ceb4a --- /dev/null +++ b/internal/internal_errors/errors.go @@ -0,0 +1,201 @@ +package internal_errors + +import ( + "fmt" + "github.com/gocql/gocql/internal/protocol" +) + +// See CQL Binary Protocol v5, section 8 for more details. +// https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec +const ( + // ErrCodeServer indicates unexpected error on server-side. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247 + ErrCodeServer = 0x0000 + // ErrCodeProtocol indicates a protocol violation by some client message. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250 + ErrCodeProtocol = 0x000A + // ErrCodeCredentials indicates missing required authentication. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254 + ErrCodeCredentials = 0x0100 + // ErrCodeUnavailable indicates unavailable error. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265 + ErrCodeUnavailable = 0x1000 + // ErrCodeOverloaded returned in case of request on overloaded node coordinator. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267 + ErrCodeOverloaded = 0x1001 + // ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269 + ErrCodeBootstrapping = 0x1002 + // ErrCodeTruncate indicates truncation exception. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270 + ErrCodeTruncate = 0x1003 + // ErrCodeWriteTimeout returned in case of timeout during the request write. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304 + ErrCodeWriteTimeout = 0x1100 + // ErrCodeReadTimeout returned in case of timeout during the request read. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321 + ErrCodeReadTimeout = 0x1200 + // ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340 + ErrCodeReadFailure = 0x1300 + // ErrCodeFunctionFailure indicates an error in user-defined function. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347 + ErrCodeFunctionFailure = 0x1400 + // ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385 + ErrCodeWriteFailure = 0x1500 + // ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386 + ErrCodeCDCWriteFailure = 0x1600 + // ErrCodeCASWriteUnknown indicates only partially completed CAS operation. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 + ErrCodeCASWriteUnknown = 0x1700 + // ErrCodeSyntax indicates the syntax error in the query. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399 + ErrCodeSyntax = 0x2000 + // ErrCodeUnauthorized indicates access rights violation by user on performed operation. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401 + ErrCodeUnauthorized = 0x2100 + // ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402 + ErrCodeInvalid = 0x2200 + // ErrCodeConfig indicates the configuration error. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403 + ErrCodeConfig = 0x2300 + // ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413 + ErrCodeAlreadyExists = 0x2400 + // ErrCodeUnprepared returned from the host for prepared statement which is unknown. + // + // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417 + ErrCodeUnprepared = 0x2500 +) + +type RequestError interface { + Code() int + Message() string + Error() string +} + +type ErrorFrame struct { + protocol.FrameHeader + + Cod int + Messag string +} + +func (e ErrorFrame) Code() int { + return e.Cod +} + +func (e ErrorFrame) Message() string { + return e.Messag +} + +func (e ErrorFrame) Error() string { + return e.Message() +} + +func (e ErrorFrame) String() string { + return fmt.Sprintf("[error code=%x message=%q]", e.Cod, e.Messag) +} + +type RequestErrUnavailable struct { + ErrorFrame + Consistency protocol.Consistency + Required int + Alive int +} + +func (e *RequestErrUnavailable) String() string { + return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive) +} + +type ErrorMap map[string]uint16 + +type RequestErrWriteTimeout struct { + ErrorFrame + Consistency protocol.Consistency + Received int + BlockFor int + WriteType string +} + +type RequestErrWriteFailure struct { + ErrorFrame + Consistency protocol.Consistency + Received int + BlockFor int + NumFailures int + WriteType string + ErrorMap ErrorMap +} + +type RequestErrCDCWriteFailure struct { + ErrorFrame +} + +type RequestErrReadTimeout struct { + ErrorFrame + Consistency protocol.Consistency + Received int + BlockFor int + DataPresent byte +} + +type RequestErrAlreadyExists struct { + ErrorFrame + Keyspace string + Table string +} + +type RequestErrUnprepared struct { + ErrorFrame + StatementId []byte +} + +type RequestErrReadFailure struct { + ErrorFrame + Consistency protocol.Consistency + Received int + BlockFor int + NumFailures int + DataPresent bool + ErrorMap ErrorMap +} + +type RequestErrFunctionFailure struct { + ErrorFrame + Keyspace string + Function string + ArgTypes []string +} + +// RequestErrCASWriteUnknown is distinct error for ErrCodeCasWriteUnknown. +// +// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 +type RequestErrCASWriteUnknown struct { + ErrorFrame + Consistency protocol.Consistency + Received int + BlockFor int +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 000000000..c6ff414fe --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,40 @@ +package logger + +import ( + "bytes" + "fmt" + "log" +) + +type StdLogger interface { + Print(v ...interface{}) + Printf(format string, v ...interface{}) + Println(v ...interface{}) +} + +type NopLogger struct{} + +func (n NopLogger) Print(_ ...interface{}) {} + +func (n NopLogger) Printf(_ string, _ ...interface{}) {} + +func (n NopLogger) Println(_ ...interface{}) {} + +type TestLogger struct { + capture bytes.Buffer +} + +func (l *TestLogger) Print(v ...interface{}) { fmt.Fprint(&l.capture, v...) } +func (l *TestLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) } +func (l *TestLogger) Println(v ...interface{}) { fmt.Fprintln(&l.capture, v...) } +func (l *TestLogger) String() string { return l.capture.String() } + +type DefaultLogger struct{} + +func (l *DefaultLogger) Print(v ...interface{}) { log.Print(v...) } +func (l *DefaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) } +func (l *DefaultLogger) Println(v ...interface{}) { log.Println(v...) } + +// Logger for logging messages. +// Deprecated: Use ClusterConfig.Logger instead. +var Logger StdLogger = &DefaultLogger{} diff --git a/internal/protocol/codec.go b/internal/protocol/codec.go new file mode 100644 index 000000000..843231782 --- /dev/null +++ b/internal/protocol/codec.go @@ -0,0 +1,2633 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "gopkg.in/inf.v0" + "math" + "math/big" + "math/bits" + "net" + "reflect" + "strconv" + "strings" + "time" +) + +var ( + bigOne = big.NewInt(1) + emptyValue reflect.Value +) + +var ( + ErrorUDTUnavailable = errors.New("UDT are not available on protocols less than 3, please update config") +) + +func IsNullableValue(value interface{}) bool { + v := reflect.ValueOf(value) + return v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Ptr +} + +// Marshaler is the interface implemented by objects that can marshal +// themselves into values understood by Cassandra. +type Marshaler interface { + MarshalCQL(info TypeInfo) ([]byte, error) +} + +// Unmarshaler is the interface implemented by objects that can unmarshal +// a Cassandra specific description of themselves. +type Unmarshaler interface { + UnmarshalCQL(info TypeInfo, data []byte) error +} + +func isNullData(info TypeInfo, data []byte) bool { + return data == nil +} + +func UnmarshalNullable(info TypeInfo, data []byte, value interface{}) error { + valueRef := reflect.ValueOf(value) + + if isNullData(info, data) { + nilValue := reflect.Zero(valueRef.Type().Elem()) + valueRef.Elem().Set(nilValue) + return nil + } + + newValue := reflect.New(valueRef.Type().Elem().Elem()) + valueRef.Elem().Set(newValue) + return Unmarshal(info, data, newValue.Interface()) +} + +func MarshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case string: + return []byte(v), nil + case []byte: + return v, nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + switch { + case k == reflect.String: + return []byte(rv.String()), nil + case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: + return rv.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *string: + *v = string(data) + return nil + case *[]byte: + if data != nil { + *v = append((*v)[:0], data...) + } else { + *v = nil + } + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + switch { + case k == reflect.String: + rv.SetString(string(data)) + return nil + case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: + var dataCopy []byte + if data != nil { + dataCopy = make([]byte, len(data)) + copy(dataCopy, data) + } + rv.SetBytes(dataCopy) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func MarshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int16: + return encShort(v), nil + case uint16: + return encShort(int16(v)), nil + case int8: + return encShort(int16(v)), nil + case uint8: + return encShort(int16(v)), nil + case int: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case int32: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case int64: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case uint: + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case uint32: + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case uint64: + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case string: + n, err := strconv.ParseInt(v, 10, 16) + if err != nil { + return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) + } + return encShort(int16(n)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func MarshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int8: + return []byte{byte(v)}, nil + case uint8: + return []byte{byte(v)}, nil + case int16: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint16: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int32: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int64: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint32: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint64: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case string: + n, err := strconv.ParseInt(v, 10, 8) + if err != nil { + return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) + } + return []byte{byte(n)}, nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func MarshalInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case uint: + if v > math.MaxUint32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case int64: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case uint64: + if v > math.MaxUint32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case int32: + return encInt(v), nil + case uint32: + return encInt(int32(v)), nil + case int16: + return encInt(int32(v)), nil + case uint16: + return encInt(int32(v)), nil + case int8: + return encInt(int32(v)), nil + case uint8: + return encInt(int32(v)), nil + case string: + i, err := strconv.ParseInt(v, 10, 32) + if err != nil { + return nil, marshalErrorf("can not marshal string to int: %s", err) + } + return encInt(int32(i)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func encInt(x int32) []byte { + return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func decInt(x []byte) int32 { + if len(x) != 4 { + return 0 + } + return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) +} + +func encShort(x int16) []byte { + p := make([]byte, 2) + p[0] = byte(x >> 8) + p[1] = byte(x) + return p +} + +func decShort(p []byte) int16 { + if len(p) != 2 { + return 0 + } + return int16(p[0])<<8 | int16(p[1]) +} + +func decTiny(p []byte) int8 { + if len(p) != 1 { + return 0 + } + return int8(p[0]) +} + +func MarshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int: + return encBigInt(int64(v)), nil + case uint: + if uint64(v) > math.MaxInt64 { + return nil, marshalErrorf("marshal bigint: value %d out of range", v) + } + return encBigInt(int64(v)), nil + case int64: + return encBigInt(v), nil + case uint64: + return encBigInt(int64(v)), nil + case int32: + return encBigInt(int64(v)), nil + case uint32: + return encBigInt(int64(v)), nil + case int16: + return encBigInt(int64(v)), nil + case uint16: + return encBigInt(int64(v)), nil + case int8: + return encBigInt(int64(v)), nil + case uint8: + return encBigInt(int64(v)), nil + case big.Int: + return encBigInt2C(&v), nil + case string: + i, err := strconv.ParseInt(value.(string), 10, 64) + if err != nil { + return nil, marshalErrorf("can not marshal string to bigint: %s", err) + } + return encBigInt(i), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + return encBigInt(v), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt64 { + return nil, marshalErrorf("marshal bigint: value %d out of range", v) + } + return encBigInt(int64(v)), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func encBigInt(x int64) []byte { + return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), + byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func bytesToInt64(data []byte) (ret int64) { + for i := range data { + ret |= int64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func bytesToUint64(data []byte) (ret uint64) { + for i := range data { + ret |= uint64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func UnmarshalBigInt(info TypeInfo, data []byte, value interface{}) error { + return UnmarshalIntlike(info, decBigInt(data), data, value) +} + +func UnmarshalInt(info TypeInfo, data []byte, value interface{}) error { + return UnmarshalIntlike(info, int64(decInt(data)), data, value) +} + +func UnmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { + return UnmarshalIntlike(info, int64(decShort(data)), data, value) +} + +func UnmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { + return UnmarshalIntlike(info, int64(decTiny(data)), data, value) +} + +func UnmarshalVarint(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case *big.Int: + return UnmarshalIntlike(info, 0, data, value) + case *uint64: + if len(data) == 9 && data[0] == 0 { + *v = bytesToUint64(data[1:]) + return nil + } + } + + if len(data) > 8 { + return UnmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value) + } + + int64Val := bytesToInt64(data) + if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 { + int64Val -= (1 << uint(len(data)*8)) + } + return UnmarshalIntlike(info, int64Val, data, value) +} + +func MarshalVarint(info TypeInfo, value interface{}) ([]byte, error) { + var ( + retBytes []byte + err error + ) + + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case uint64: + if v > uint64(math.MaxInt64) { + retBytes = make([]byte, 9) + binary.BigEndian.PutUint64(retBytes[1:], v) + } else { + retBytes = make([]byte, 8) + binary.BigEndian.PutUint64(retBytes, v) + } + default: + retBytes, err = MarshalBigInt(info, value) + } + + if err == nil { + // trim down to most significant byte + i := 0 + for ; i < len(retBytes)-1; i++ { + b0 := retBytes[i] + if b0 != 0 && b0 != 0xFF { + break + } + + b1 := retBytes[i+1] + if b0 == 0 && b1 != 0 { + if b1&0x80 == 0 { + i++ + } + break + } + + if b0 == 0xFF && b1 != 0xFF { + if b1&0x80 > 0 { + i++ + } + break + } + } + retBytes = retBytes[i:] + } + + return retBytes, err +} + +func UnmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error { + switch v := value.(type) { + case *int: + if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int(int64Val) + return nil + case *uint: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + *v = uint(unitVal) & 0xFFFFFFFF + case TypeSmallInt: + *v = uint(unitVal) & 0xFFFF + case TypeTinyInt: + *v = uint(unitVal) & 0xFF + default: + if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v) + } + *v = uint(unitVal) + } + return nil + case *int64: + *v = int64Val + return nil + case *uint64: + switch info.Type() { + case TypeInt: + *v = uint64(int64Val) & 0xFFFFFFFF + case TypeSmallInt: + *v = uint64(int64Val) & 0xFFFF + case TypeTinyInt: + *v = uint64(int64Val) & 0xFF + default: + *v = uint64(int64Val) + } + return nil + case *int32: + if int64Val < math.MinInt32 || int64Val > math.MaxInt32 { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int32(int64Val) + return nil + case *uint32: + switch info.Type() { + case TypeInt: + *v = uint32(int64Val) & 0xFFFFFFFF + case TypeSmallInt: + *v = uint32(int64Val) & 0xFFFF + case TypeTinyInt: + *v = uint32(int64Val) & 0xFF + default: + if int64Val < 0 || int64Val > math.MaxUint32 { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = uint32(int64Val) & 0xFFFFFFFF + } + return nil + case *int16: + if int64Val < math.MinInt16 || int64Val > math.MaxInt16 { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int16(int64Val) + return nil + case *uint16: + switch info.Type() { + case TypeSmallInt: + *v = uint16(int64Val) & 0xFFFF + case TypeTinyInt: + *v = uint16(int64Val) & 0xFF + default: + if int64Val < 0 || int64Val > math.MaxUint16 { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = uint16(int64Val) & 0xFFFF + } + return nil + case *int8: + if int64Val < math.MinInt8 || int64Val > math.MaxInt8 { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int8(int64Val) + return nil + case *uint8: + if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { + return UnmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = uint8(int64Val) & 0xFF + return nil + case *big.Int: + decBigInt2C(data, v) + return nil + case *string: + *v = strconv.FormatInt(int64Val, 10) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + + switch rv.Type().Kind() { + case reflect.Int: + if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { + return UnmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Int64: + rv.SetInt(int64Val) + return nil + case reflect.Int32: + if int64Val < math.MinInt32 || int64Val > math.MaxInt32 { + return UnmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Int16: + if int64Val < math.MinInt16 || int64Val > math.MaxInt16 { + return UnmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Int8: + if int64Val < math.MinInt8 || int64Val > math.MaxInt8 { + return UnmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Uint: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { + return UnmarshalErrorf("unmarshal int: value %d out of range for %s", unitVal, rv.Type()) + } + rv.SetUint(unitVal) + } + return nil + case reflect.Uint64: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + rv.SetUint(unitVal) + } + return nil + case reflect.Uint32: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if int64Val < 0 || int64Val > math.MaxUint32 { + return UnmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(unitVal & 0xFFFFFFFF) + } + return nil + case reflect.Uint16: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if int64Val < 0 || int64Val > math.MaxUint16 { + return UnmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(unitVal & 0xFFFF) + } + return nil + case reflect.Uint8: + if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { + return UnmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(uint64(int64Val) & 0xff) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func decBigInt(data []byte) int64 { + if len(data) != 8 { + return 0 + } + return int64(data[0])<<56 | int64(data[1])<<48 | + int64(data[2])<<40 | int64(data[3])<<32 | + int64(data[4])<<24 | int64(data[5])<<16 | + int64(data[6])<<8 | int64(data[7]) +} + +func MarshalBool(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case bool: + return encBool(v), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Bool: + return encBool(rv.Bool()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func encBool(v bool) []byte { + if v { + return []byte{1} + } + return []byte{0} +} + +func UnmarshalBool(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *bool: + *v = decBool(data) + return nil + } + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Bool: + rv.SetBool(decBool(data)) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func decBool(v []byte) bool { + if len(v) == 0 { + return false + } + return v[0] != 0 +} + +func MarshalFloat(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case float32: + return encInt(int32(math.Float32bits(v))), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float32: + return encInt(int32(math.Float32bits(float32(rv.Float())))), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalFloat(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *float32: + *v = math.Float32frombits(uint32(decInt(data))) + return nil + } + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Float32: + rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func MarshalDouble(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case float64: + return encBigInt(int64(math.Float64bits(v))), nil + } + if value == nil { + return nil, nil + } + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float64: + return encBigInt(int64(math.Float64bits(rv.Float()))), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalDouble(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *float64: + *v = math.Float64frombits(uint64(decBigInt(data))) + return nil + } + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Float64: + rv.SetFloat(math.Float64frombits(uint64(decBigInt(data)))) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func MarshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } + + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case inf.Dec: + unscaled := encBigInt2C(v.UnscaledBig()) + if unscaled == nil { + return nil, marshalErrorf("can not marshal %T into %s", value, info) + } + + buf := make([]byte, 4+len(unscaled)) + copy(buf[0:4], encInt(int32(v.Scale()))) + copy(buf[4:], unscaled) + return buf, nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *inf.Dec: + if len(data) < 4 { + return UnmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) + } + scale := decInt(data[0:4]) + unscaled := decBigInt2C(data[4:], nil) + *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +// decBigInt2C sets the value of n to the big-endian two's complement +// value stored in the given data. If data[0]&80 != 0, the number +// is negative. If data is empty, the result will be 0. +func decBigInt2C(data []byte, n *big.Int) *big.Int { + if n == nil { + n = new(big.Int) + } + n.SetBytes(data) + if len(data) > 0 && data[0]&0x80 > 0 { + n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) + } + return n +} + +// encBigInt2C returns the big-endian two's complement +// form of n. +func encBigInt2C(n *big.Int) []byte { + switch n.Sign() { + case 0: + return []byte{0} + case 1: + b := n.Bytes() + if b[0]&0x80 > 0 { + b = append([]byte{0}, b...) + } + return b + case -1: + length := uint(n.BitLen()/8+1) * 8 + b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() + // When the most significant bit is on a byte + // boundary, we can get some extra significant + // bits, so strip them off when that happens. + if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { + b = b[1:] + } + return b + } + return nil +} + +func MarshalTime(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int64: + return encBigInt(v), nil + case time.Duration: + return encBigInt(v.Nanoseconds()), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func MarshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int64: + return encBigInt(v), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + return encBigInt(x), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalTime(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *int64: + *v = decBigInt(data) + return nil + case *time.Duration: + *v = time.Duration(decBigInt(data)) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Int64: + rv.SetInt(decBigInt(data)) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func UnmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *int64: + *v = decBigInt(data) + return nil + case *time.Time: + if len(data) == 0 { + *v = time.Time{} + return nil + } + x := decBigInt(data) + sec := x / 1000 + nsec := (x - sec*1000) * 1000000 + *v = time.Unix(sec, nsec).In(time.UTC) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Int64: + rv.SetInt(decBigInt(data)) + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +const millisecondsInADay int64 = 24 * 60 * 60 * 1000 + +func MarshalDate(info TypeInfo, value interface{}) ([]byte, error) { + var timestamp int64 + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int64: + timestamp = v + x := timestamp/millisecondsInADay + int64(1<<31) + return encInt(int32(x)), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/millisecondsInADay + int64(1<<31) + return encInt(int32(x)), nil + case *time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/millisecondsInADay + int64(1<<31) + return encInt(int32(x)), nil + case string: + if v == "" { + return []byte{}, nil + } + t, err := time.Parse("2006-01-02", v) + if err != nil { + return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) + } + timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) + x := timestamp/millisecondsInADay + int64(1<<31) + return encInt(int32(x)), nil + } + + if value == nil { + return nil, nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalDate(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *time.Time: + if len(data) == 0 { + *v = time.Time{} + return nil + } + var origin uint32 = 1 << 31 + var current uint32 = binary.BigEndian.Uint32(data) + timestamp := (int64(current) - int64(origin)) * millisecondsInADay + *v = time.UnixMilli(timestamp).In(time.UTC) + return nil + case *string: + if len(data) == 0 { + *v = "" + return nil + } + var origin uint32 = 1 << 31 + var current uint32 = binary.BigEndian.Uint32(data) + timestamp := (int64(current) - int64(origin)) * millisecondsInADay + *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func MarshalDuration(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int64: + return encVints(0, 0, v), nil + case time.Duration: + return encVints(0, 0, v.Nanoseconds()), nil + case string: + d, err := time.ParseDuration(v) + if err != nil { + return nil, err + } + return encVints(0, 0, d.Nanoseconds()), nil + case Duration: + return encVints(v.Months, v.Days, v.Nanoseconds), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalDuration(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *Duration: + if len(data) == 0 { + *v = Duration{ + Months: 0, + Days: 0, + Nanoseconds: 0, + } + return nil + } + months, days, nanos, err := decVints(data) + if err != nil { + return UnmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) + } + *v = Duration{ + Months: months, + Days: days, + Nanoseconds: nanos, + } + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func decVints(data []byte) (int32, int32, int64, error) { + month, i, err := decVint(data, 0) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) + } + days, i, err := decVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) + } + nanos, _, err := decVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) + } + return int32(month), int32(days), nanos, err +} + +func decVint(data []byte, start int) (int64, int, error) { + if len(data) <= start { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[start] + if firstByte&0x80 == 0 { + return decIntZigZag(uint64(firstByte)), start + 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < start+numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) + } + for i := start; i < start+numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return decIntZigZag(ret), start + numBytes + 1, nil +} + +func decIntZigZag(n uint64) int64 { + return int64((n >> 1) ^ -(n & 1)) +} + +func encIntZigZag(n int64) uint64 { + return uint64((n >> 63) ^ (n << 1)) +} + +func encVints(months int32, seconds int32, nanos int64) []byte { + buf := append(encVint(int64(months)), encVint(int64(seconds))...) + return append(buf, encVint(nanos)...) +} + +func encVint(v int64) []byte { + vEnc := encIntZigZag(v) + lead0 := bits.LeadingZeros64(vEnc) + numBytes := (639 - lead0*9) >> 6 + + // It can be 1 or 0 is v ==0 + if numBytes <= 1 { + return []byte{byte(vEnc)} + } + extraBytes := numBytes - 1 + var buf = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + buf[i] = byte(vEnc) + vEnc >>= 8 + } + buf[0] |= byte(^(0xff >> uint(extraBytes))) + return buf +} + +func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { + if info.proto > ProtoVersion2 { + if n > math.MaxInt32 { + return marshalErrorf("marshal: collection too large") + } + + buf.WriteByte(byte(n >> 24)) + buf.WriteByte(byte(n >> 16)) + buf.WriteByte(byte(n >> 8)) + buf.WriteByte(byte(n)) + } else { + if n > math.MaxUint16 { + return marshalErrorf("marshal: collection too large") + } + + buf.WriteByte(byte(n >> 8)) + buf.WriteByte(byte(n)) + } + + return nil +} + +func MarshalList(info TypeInfo, value interface{}) ([]byte, error) { + listInfo, ok := info.(CollectionType) + if !ok { + return nil, marshalErrorf("marshal: can not marshal non collection type into list") + } + + if value == nil { + return nil, nil + } else if _, ok := value.(UnsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + + if err := writeCollectionSize(listInfo, n, buf); err != nil { + return nil, err + } + + for i := 0; i < n; i++ { + item, err := Marshal(listInfo.Elem, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + itemLen := len(item) + // Set the value to null for supported protocols + if item == nil && listInfo.proto > ProtoVersion2 { + itemLen = -1 + } + if err := writeCollectionSize(listInfo, itemLen, buf); err != nil { + return nil, err + } + buf.Write(item) + } + return buf.Bytes(), nil + case reflect.Map: + elem := t.Elem() + if elem.Kind() == reflect.Struct && elem.NumField() == 0 { + rkeys := rv.MapKeys() + keys := make([]interface{}, len(rkeys)) + for i := 0; i < len(keys); i++ { + keys[i] = rkeys[i].Interface() + } + return MarshalList(listInfo, keys) + } + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) { + if info.proto > ProtoVersion2 { + if len(data) < 4 { + return 0, 0, UnmarshalErrorf("unmarshal list: unexpected eof") + } + size = int(int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3])) + read = 4 + } else { + if len(data) < 2 { + return 0, 0, UnmarshalErrorf("unmarshal list: unexpected eof") + } + size = int(data[0])<<8 | int(data[1]) + read = 2 + } + return +} + +func UnmarshalList(info TypeInfo, data []byte, value interface{}) error { + listInfo, ok := info.(CollectionType) + if !ok { + return UnmarshalErrorf("unmarshal: can not unmarshal none collection type into list") + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return UnmarshalErrorf("unmarshal list: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + n, p, err := readCollectionSize(listInfo, data) + if err != nil { + return err + } + data = data[p:] + if k == reflect.Array { + if rv.Len() != n { + return UnmarshalErrorf("unmarshal list: array with wrong size") + } + } else { + rv.Set(reflect.MakeSlice(t, n, n)) + } + for i := 0; i < n; i++ { + m, p, err := readCollectionSize(listInfo, data) + if err != nil { + return err + } + data = data[p:] + // In case m < 0, the value is null, and unmarshalData should be nil. + var unmarshalData []byte + if m >= 0 { + if len(data) < m { + return UnmarshalErrorf("unmarshal list: unexpected eof") + } + unmarshalData = data[:m] + data = data[m:] + } + if err := Unmarshal(listInfo.Elem, unmarshalData, rv.Index(i).Addr().Interface()); err != nil { + return err + } + } + return nil + } + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func MarshalMap(info TypeInfo, value interface{}) ([]byte, error) { + mapInfo, ok := info.(CollectionType) + if !ok { + return nil, marshalErrorf("marshal: can not marshal none collection type into map") + } + + if value == nil { + return nil, nil + } else if _, ok := value.(UnsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + + t := rv.Type() + if t.Kind() != reflect.Map { + return nil, marshalErrorf("can not marshal %T into %s", value, info) + } + + if rv.IsNil() { + return nil, nil + } + + buf := &bytes.Buffer{} + n := rv.Len() + + if err := writeCollectionSize(mapInfo, n, buf); err != nil { + return nil, err + } + + keys := rv.MapKeys() + for _, key := range keys { + item, err := Marshal(mapInfo.Key, key.Interface()) + if err != nil { + return nil, err + } + itemLen := len(item) + // Set the key to null for supported protocols + if item == nil && mapInfo.proto > ProtoVersion2 { + itemLen = -1 + } + if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { + return nil, err + } + buf.Write(item) + + item, err = Marshal(mapInfo.Elem, rv.MapIndex(key).Interface()) + if err != nil { + return nil, err + } + itemLen = len(item) + // Set the value to null for supported protocols + if item == nil && mapInfo.proto > ProtoVersion2 { + itemLen = -1 + } + if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { + return nil, err + } + buf.Write(item) + } + return buf.Bytes(), nil +} + +func UnmarshalMap(info TypeInfo, data []byte, value interface{}) error { + mapInfo, ok := info.(CollectionType) + if !ok { + return UnmarshalErrorf("unmarshal: can not unmarshal none collection type into map") + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + if t.Kind() != reflect.Map { + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) + } + if data == nil { + rv.Set(reflect.Zero(t)) + return nil + } + n, p, err := readCollectionSize(mapInfo, data) + if err != nil { + return err + } + if n < 0 { + return UnmarshalErrorf("negative map size %d", n) + } + rv.Set(reflect.MakeMapWithSize(t, n)) + data = data[p:] + for i := 0; i < n; i++ { + m, p, err := readCollectionSize(mapInfo, data) + if err != nil { + return err + } + data = data[p:] + key := reflect.New(t.Key()) + // In case m < 0, the key is null, and unmarshalData should be nil. + var unmarshalData []byte + if m >= 0 { + if len(data) < m { + return UnmarshalErrorf("unmarshal map: unexpected eof") + } + unmarshalData = data[:m] + data = data[m:] + } + if err := Unmarshal(mapInfo.Key, unmarshalData, key.Interface()); err != nil { + return err + } + + m, p, err = readCollectionSize(mapInfo, data) + if err != nil { + return err + } + data = data[p:] + val := reflect.New(t.Elem()) + + // In case m < 0, the value is null, and unmarshalData should be nil. + unmarshalData = nil + if m >= 0 { + if len(data) < m { + return UnmarshalErrorf("unmarshal map: unexpected eof") + } + unmarshalData = data[:m] + data = data[m:] + } + if err := Unmarshal(mapInfo.Elem, unmarshalData, val.Interface()); err != nil { + return err + } + + rv.SetMapIndex(key.Elem(), val.Elem()) + } + return nil +} + +func MarshalUUID(info TypeInfo, value interface{}) ([]byte, error) { + switch val := value.(type) { + case UnsetColumn: + return nil, nil + case UUID: + return val.Bytes(), nil + case [16]byte: + return val[:], nil + case []byte: + if len(val) != 16 { + return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info) + } + return val, nil + case string: + b, err := ParseUUID(val) + if err != nil { + return nil, err + } + return b[:], nil + } + + if value == nil { + return nil, nil + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func UnmarshalUUID(info TypeInfo, data []byte, value interface{}) error { + if len(data) == 0 { + switch v := value.(type) { + case *string: + *v = "" + case *[]byte: + *v = nil + case *UUID: + *v = UUID{} + default: + return UnmarshalErrorf("can not unmarshal X %s into %T", info, value) + } + + return nil + } + + if len(data) != 16 { + return UnmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long") + } + + switch v := value.(type) { + case *[16]byte: + copy((*v)[:], data) + return nil + case *UUID: + copy((*v)[:], data) + return nil + } + + u, err := UUIDFromBytes(data) + if err != nil { + return UnmarshalErrorf("unable to parse UUID: %s", err) + } + + switch v := value.(type) { + case *string: + *v = u.String() + return nil + case *[]byte: + *v = u[:] + return nil + } + return UnmarshalErrorf("can not unmarshal X %s into %T", info, value) +} + +func UnmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *time.Time: + id, err := UUIDFromBytes(data) + if err != nil { + return err + } else if id.Version() != 1 { + return UnmarshalErrorf("invalid timeuuid") + } + *v = id.Time() + return nil + default: + return UnmarshalUUID(info, data, value) + } +} + +func MarshalInet(info TypeInfo, value interface{}) ([]byte, error) { + // we return either the 4 or 16 byte representation of an + // ip address here otherwise the db value will be prefixed + // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 + switch val := value.(type) { + case UnsetColumn: + return nil, nil + case net.IP: + t := val.To4() + if t == nil { + return val.To16(), nil + } + return t, nil + case string: + b := net.ParseIP(val) + if b != nil { + t := b.To4() + if t == nil { + return b.To16(), nil + } + return t, nil + } + return nil, marshalErrorf("cannot marshal. invalid ip string %s", val) + } + + if value == nil { + return nil, nil + } + + return nil, marshalErrorf("cannot marshal %T into %s", value, info) +} + +func UnmarshalInet(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *net.IP: + if x := len(data); !(x == 4 || x == 16) { + return UnmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) + } + buf := CopyBytes(data) + ip := net.IP(buf) + if v4 := ip.To4(); v4 != nil { + *v = v4 + return nil + } + *v = ip + return nil + case *string: + if len(data) == 0 { + *v = "" + return nil + } + ip := net.IP(data) + if v4 := ip.To4(); v4 != nil { + *v = v4.String() + return nil + } + *v = ip.String() + return nil + } + return UnmarshalErrorf("cannot unmarshal %s into %T", info, value) +} + +func MarshalTuple(info TypeInfo, value interface{}) ([]byte, error) { + tuple := info.(TupleTypeInfo) + switch v := value.(type) { + case UnsetColumn: + return nil, UnmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples") + case []interface{}: + if len(v) != len(tuple.Elems) { + return nil, UnmarshalErrorf("cannont marshal tuple: wrong number of elements") + } + + var buf []byte + for i, elem := range v { + if elem == nil { + buf = appendInt(buf, int32(-1)) + continue + } + + data, err := Marshal(tuple.Elems[i], elem) + if err != nil { + return nil, err + } + + n := len(data) + buf = appendInt(buf, int32(n)) + buf = append(buf, data...) + } + + return buf, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + + switch k { + case reflect.Struct: + if v := t.NumField(); v != len(tuple.Elems) { + return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) + } + + var buf []byte + for i, elem := range tuple.Elems { + field := rv.Field(i) + + if field.Kind() == reflect.Ptr && field.IsNil() { + buf = appendInt(buf, int32(-1)) + continue + } + + data, err := Marshal(elem, field.Interface()) + if err != nil { + return nil, err + } + + n := len(data) + buf = appendInt(buf, int32(n)) + buf = append(buf, data...) + } + + return buf, nil + case reflect.Slice, reflect.Array: + size := rv.Len() + if size != len(tuple.Elems) { + return nil, marshalErrorf("can not marshal tuple into %v of length %d need %d elements", k, size, len(tuple.Elems)) + } + + var buf []byte + for i, elem := range tuple.Elems { + item := rv.Index(i) + + if item.Kind() == reflect.Ptr && item.IsNil() { + buf = appendInt(buf, int32(-1)) + continue + } + + data, err := Marshal(elem, item.Interface()) + if err != nil { + return nil, err + } + + n := len(data) + buf = appendInt(buf, int32(n)) + buf = append(buf, data...) + } + + return buf, nil + } + + return nil, marshalErrorf("cannot marshal %T into %s", value, tuple) +} + +func readBytes(p []byte) ([]byte, []byte) { + // TODO: really should use a framer + size := ReadInt(p) + p = p[4:] + if size < 0 { + return nil, p + } + return p[:size], p[size:] +} + +// currently only support unmarshal into a list of values, this makes it possible +// to support tuples without changing the query API. In the future this can be extend +// to allow unmarshalling into custom tuple types. +func UnmarshalTuple(info TypeInfo, data []byte, value interface{}) error { + if v, ok := value.(Unmarshaler); ok { + return v.UnmarshalCQL(info, data) + } + + tuple := info.(TupleTypeInfo) + switch v := value.(type) { + case []interface{}: + for i, elem := range tuple.Elems { + // each element inside data is a [bytes] + var p []byte + if len(data) >= 4 { + p, data = readBytes(data) + } + err := Unmarshal(elem, p, v[i]) + if err != nil { + return err + } + } + + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + + switch k { + case reflect.Struct: + if v := t.NumField(); v != len(tuple.Elems) { + return UnmarshalErrorf("can not unmarshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) + } + + for i, elem := range tuple.Elems { + var p []byte + if len(data) >= 4 { + p, data = readBytes(data) + } + + v, err := elem.NewWithError() + if err != nil { + return err + } + if err := Unmarshal(elem, p, v); err != nil { + return err + } + + switch rv.Field(i).Kind() { + case reflect.Ptr: + if p != nil { + rv.Field(i).Set(reflect.ValueOf(v)) + } else { + rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v))) + } + default: + rv.Field(i).Set(reflect.ValueOf(v).Elem()) + } + } + + return nil + case reflect.Slice, reflect.Array: + if k == reflect.Array { + size := rv.Len() + if size != len(tuple.Elems) { + return UnmarshalErrorf("can not unmarshal tuple into array of length %d need %d elements", size, len(tuple.Elems)) + } + } else { + rv.Set(reflect.MakeSlice(t, len(tuple.Elems), len(tuple.Elems))) + } + + for i, elem := range tuple.Elems { + var p []byte + if len(data) >= 4 { + p, data = readBytes(data) + } + + v, err := elem.NewWithError() + if err != nil { + return err + } + if err := Unmarshal(elem, p, v); err != nil { + return err + } + + switch rv.Index(i).Kind() { + case reflect.Ptr: + if p != nil { + rv.Index(i).Set(reflect.ValueOf(v)) + } else { + rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v))) + } + default: + rv.Index(i).Set(reflect.ValueOf(v).Elem()) + } + } + + return nil + } + + return UnmarshalErrorf("cannot unmarshal %s into %T", info, value) +} + +// UDTMarshaler is an interface which should be implemented by users wishing to +// handle encoding UDT types to sent to Cassandra. Note: due to current implentations +// methods defined for this interface must be value receivers not pointer receivers. +type UDTMarshaler interface { + // MarshalUDT will be called for each field in the the UDT returned by Cassandra, + // the implementor should marshal the type to return by for example calling + // Marshal. + MarshalUDT(name string, info TypeInfo) ([]byte, error) +} + +func MarshalUDT(info TypeInfo, value interface{}) ([]byte, error) { + udt := info.(UDTTypeInfo) + + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case UnsetColumn: + return nil, UnmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types") + case UDTMarshaler: + var buf []byte + for _, e := range udt.Elements { + data, err := v.MarshalUDT(e.Name, e.Type) + if err != nil { + return nil, err + } + + buf = appendBytes(buf, data) + } + + return buf, nil + case map[string]interface{}: + var buf []byte + for _, e := range udt.Elements { + val, ok := v[e.Name] + + var data []byte + + if ok { + var err error + data, err = Marshal(e.Type, val) + if err != nil { + return nil, err + } + } + + buf = appendBytes(buf, data) + } + + return buf, nil + } + + k := reflect.ValueOf(value) + if k.Kind() == reflect.Ptr { + if k.IsNil() { + return nil, marshalErrorf("cannot marshal %T into %s", value, info) + } + k = k.Elem() + } + + if k.Kind() != reflect.Struct || !k.IsValid() { + return nil, marshalErrorf("cannot marshal %T into %s", value, info) + } + + fields := make(map[string]reflect.Value) + t := reflect.TypeOf(value) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + + if tag := sf.Tag.Get("cql"); tag != "" { + fields[tag] = k.Field(i) + } + } + + var buf []byte + for _, e := range udt.Elements { + f, ok := fields[e.Name] + if !ok { + f = k.FieldByName(e.Name) + } + + var data []byte + if f.IsValid() && f.CanInterface() { + var err error + data, err = Marshal(e.Type, f.Interface()) + if err != nil { + return nil, err + } + } + + buf = appendBytes(buf, data) + } + + return buf, nil +} + +// UDTUnmarshaler should be implemented by users wanting to implement custom +// UDT unmarshaling. +type UDTUnmarshaler interface { + // UnmarshalUDT will be called for each field in the UDT return by Cassandra, + // the implementor should unmarshal the data into the value of their chosing, + // for example by calling Unmarshal. + UnmarshalUDT(name string, info TypeInfo, data []byte) error +} + +func UnmarshalUDT(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case UDTUnmarshaler: + udt := info.(UDTTypeInfo) + + for id, e := range udt.Elements { + if len(data) == 0 { + return nil + } + if len(data) < 4 { + return UnmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) + } + + var p []byte + p, data = readBytes(data) + if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil { + return err + } + } + + return nil + case *map[string]interface{}: + udt := info.(UDTTypeInfo) + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + + rv = rv.Elem() + t := rv.Type() + if t.Kind() != reflect.Map { + return UnmarshalErrorf("can not unmarshal %s into %T", info, value) + } else if data == nil { + rv.Set(reflect.Zero(t)) + return nil + } + + rv.Set(reflect.MakeMap(t)) + m := *v + + for id, e := range udt.Elements { + if len(data) == 0 { + return nil + } + if len(data) < 4 { + return UnmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) + } + + valType, err := goType(e.Type) + if err != nil { + return UnmarshalErrorf("can not unmarshal %s: %v", info, err) + } + + val := reflect.New(valType) + + var p []byte + p, data = readBytes(data) + + if err := Unmarshal(e.Type, p, val.Interface()); err != nil { + return err + } + + m[e.Name] = val.Elem().Interface() + } + + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return UnmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + k := rv.Elem() + if k.Kind() != reflect.Struct || !k.IsValid() { + return UnmarshalErrorf("cannot unmarshal %s into %T", info, value) + } + + if len(data) == 0 { + if k.CanSet() { + k.Set(reflect.Zero(k.Type())) + } + + return nil + } + + t := k.Type() + fields := make(map[string]reflect.Value, t.NumField()) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + + if tag := sf.Tag.Get("cql"); tag != "" { + fields[tag] = k.Field(i) + } + } + + udt := info.(UDTTypeInfo) + for id, e := range udt.Elements { + if len(data) == 0 { + return nil + } + if len(data) < 4 { + // UDT def does not match the column value + return UnmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) + } + + var p []byte + p, data = readBytes(data) + + f, ok := fields[e.Name] + if !ok { + f = k.FieldByName(e.Name) + if f == emptyValue { + // skip fields which exist in the UDT but not in + // the struct passed in + continue + } + } + + if !f.IsValid() || !f.CanAddr() { + return UnmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name) + } + + fk := f.Addr().Interface() + if err := Unmarshal(e.Type, p, fk); err != nil { + return err + } + } + + return nil +} + +type MarshalError string + +func (m MarshalError) Error() string { + return string(m) +} + +func marshalErrorf(format string, args ...interface{}) MarshalError { + return MarshalError(fmt.Sprintf(format, args...)) +} + +type UnmarshalError string + +func (m UnmarshalError) Error() string { + return string(m) +} + +func UnmarshalErrorf(format string, args ...interface{}) UnmarshalError { + return UnmarshalError(fmt.Sprintf(format, args...)) +} + +func Marshal(info TypeInfo, value interface{}) ([]byte, error) { + if info.Version() < ProtoVersion1 { + panic("protocol version not set") + } + + if valueRef := reflect.ValueOf(value); valueRef.Kind() == reflect.Ptr { + if valueRef.IsNil() { + return nil, nil + } else if v, ok := value.(Marshaler); ok { + return v.MarshalCQL(info) + } else { + return Marshal(info, valueRef.Elem().Interface()) + } + } + + if v, ok := value.(Marshaler); ok { + return v.MarshalCQL(info) + } + + switch info.Type() { + case TypeVarchar, TypeAscii, TypeBlob, TypeText: + return MarshalVarchar(info, value) + case TypeBoolean: + return MarshalBool(info, value) + case TypeTinyInt: + return MarshalTinyInt(info, value) + case TypeSmallInt: + return MarshalSmallInt(info, value) + case TypeInt: + return MarshalInt(info, value) + case TypeBigInt, TypeCounter: + return MarshalBigInt(info, value) + case TypeFloat: + return MarshalFloat(info, value) + case TypeDouble: + return MarshalDouble(info, value) + case TypeDecimal: + return MarshalDecimal(info, value) + case TypeTime: + return MarshalTime(info, value) + case TypeTimestamp: + return MarshalTimestamp(info, value) + case TypeList, TypeSet: + return MarshalList(info, value) + case TypeMap: + return MarshalMap(info, value) + case TypeUUID, TypeTimeUUID: + return MarshalUUID(info, value) + case TypeVarint: + return MarshalVarint(info, value) + case TypeInet: + return MarshalInet(info, value) + case TypeTuple: + return MarshalTuple(info, value) + case TypeUDT: + return MarshalUDT(info, value) + case TypeDate: + return MarshalDate(info, value) + case TypeDuration: + return MarshalDuration(info, value) + } + + // detect protocol 2 UDT + if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { + return nil, ErrorUDTUnavailable + } + + // TODO(tux21b): add the remaining types + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func Unmarshal(info TypeInfo, data []byte, value interface{}) error { + if v, ok := value.(Unmarshaler); ok { + return v.UnmarshalCQL(info, data) + } + + if IsNullableValue(value) { + return UnmarshalNullable(info, data, value) + } + + switch info.Type() { + case TypeVarchar, TypeAscii, TypeBlob, TypeText: + return UnmarshalVarchar(info, data, value) + case TypeBoolean: + return UnmarshalBool(info, data, value) + case TypeInt: + return UnmarshalInt(info, data, value) + case TypeBigInt, TypeCounter: + return UnmarshalBigInt(info, data, value) + case TypeVarint: + return UnmarshalVarint(info, data, value) + case TypeSmallInt: + return UnmarshalSmallInt(info, data, value) + case TypeTinyInt: + return UnmarshalTinyInt(info, data, value) + case TypeFloat: + return UnmarshalFloat(info, data, value) + case TypeDouble: + return UnmarshalDouble(info, data, value) + case TypeDecimal: + return UnmarshalDecimal(info, data, value) + case TypeTime: + return UnmarshalTime(info, data, value) + case TypeTimestamp: + return UnmarshalTimestamp(info, data, value) + case TypeList, TypeSet: + return UnmarshalList(info, data, value) + case TypeMap: + return UnmarshalMap(info, data, value) + case TypeTimeUUID: + return UnmarshalTimeUUID(info, data, value) + case TypeUUID: + return UnmarshalUUID(info, data, value) + case TypeInet: + return UnmarshalInet(info, data, value) + case TypeTuple: + return UnmarshalTuple(info, data, value) + case TypeUDT: + return UnmarshalUDT(info, data, value) + case TypeDate: + return UnmarshalDate(info, data, value) + case TypeDuration: + return UnmarshalDuration(info, data, value) + } + + // detect protocol 2 UDT + if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { + return ErrorUDTUnavailable + } + + // TODO(tux21b): add the remaining types + return fmt.Errorf("can not unmarshal %s into %T", info, value) +} + +// TypeInfo describes a Cassandra specific data type. +type TypeInfo interface { + Type() Type + Version() byte + Custom() string + + // New creates a pointer to an empty version of whatever type + // is referenced by the TypeInfo receiver. + // + // If there is no corresponding Go type for the CQL type, New panics. + // + // Deprecated: Use NewWithError instead. + New() interface{} + + // NewWithError creates a pointer to an empty version of whatever type + // is referenced by the TypeInfo receiver. + // + // If there is no corresponding Go type for the CQL type, NewWithError returns an error. + NewWithError() (interface{}, error) +} + +type NativeType struct { + proto byte + Typ Type + Cust string // only used for TypeCustom +} + +func NewNativeType(proto byte, typ Type, custom string) NativeType { + return NativeType{proto, typ, custom} +} + +func (t NativeType) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (t NativeType) New() interface{} { + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +func (s NativeType) Type() Type { + return s.Typ +} + +func (s NativeType) Version() byte { + return s.proto +} + +func (s NativeType) Custom() string { + return s.Cust +} + +func (s NativeType) String() string { + switch s.Typ { + case TypeCustom: + return fmt.Sprintf("%s(%s)", s.Typ, s.Cust) + default: + return s.Typ.String() + } +} + +type CollectionType struct { + NativeType + Key TypeInfo // only used for TypeMap + Elem TypeInfo // only used for TypeMap, TypeList and TypeSet +} + +func (t CollectionType) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (t CollectionType) New() interface{} { + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +func (c CollectionType) String() string { + switch c.Typ { + case TypeMap: + return fmt.Sprintf("%s(%s, %s)", c.Typ, c.Key, c.Elem) + case TypeList, TypeSet: + return fmt.Sprintf("%s(%s)", c.Typ, c.Elem) + case TypeCustom: + return fmt.Sprintf("%s(%s)", c.Typ, c.Cust) + default: + return c.Typ.String() + } +} + +type TupleTypeInfo struct { + NativeType + Elems []TypeInfo +} + +func (t TupleTypeInfo) String() string { + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("%s(", t.Typ)) + for _, elem := range t.Elems { + buf.WriteString(fmt.Sprintf("%s, ", elem)) + } + buf.Truncate(buf.Len() - 2) + buf.WriteByte(')') + return buf.String() +} + +func (t TupleTypeInfo) NewWithError() (interface{}, error) { + typ, err := goType(t) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (t TupleTypeInfo) New() interface{} { + val, err := t.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +type UDTField struct { + Name string + Type TypeInfo +} + +type UDTTypeInfo struct { + NativeType + KeySpace string + Name string + Elements []UDTField +} + +func (u UDTTypeInfo) NewWithError() (interface{}, error) { + typ, err := goType(u) + if err != nil { + return nil, err + } + return reflect.New(typ).Interface(), nil +} + +func (u UDTTypeInfo) New() interface{} { + val, err := u.NewWithError() + if err != nil { + panic(err.Error()) + } + return val +} + +func (u UDTTypeInfo) String() string { + buf := &bytes.Buffer{} + + fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) + first := true + for _, e := range u.Elements { + if !first { + fmt.Fprint(buf, ",") + } else { + first = false + } + + fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) + } + fmt.Fprint(buf, "}") + + return buf.String() +} + +// String returns a human readable name for the Cassandra datatype +// described by t. +// Type is the identifier of a Cassandra internal datatype. +type Type int + +const ( + TypeCustom Type = 0x0000 + TypeAscii Type = 0x0001 + TypeBigInt Type = 0x0002 + TypeBlob Type = 0x0003 + TypeBoolean Type = 0x0004 + TypeCounter Type = 0x0005 + TypeDecimal Type = 0x0006 + TypeDouble Type = 0x0007 + TypeFloat Type = 0x0008 + TypeInt Type = 0x0009 + TypeText Type = 0x000A + TypeTimestamp Type = 0x000B + TypeUUID Type = 0x000C + TypeVarchar Type = 0x000D + TypeVarint Type = 0x000E + TypeTimeUUID Type = 0x000F + TypeInet Type = 0x0010 + TypeDate Type = 0x0011 + TypeTime Type = 0x0012 + TypeSmallInt Type = 0x0013 + TypeTinyInt Type = 0x0014 + TypeDuration Type = 0x0015 + TypeList Type = 0x0020 + TypeMap Type = 0x0021 + TypeSet Type = 0x0022 + TypeUDT Type = 0x0030 + TypeTuple Type = 0x0031 +) + +// String returns the name of the identifier. +func (t Type) String() string { + switch t { + case TypeCustom: + return "custom" + case TypeAscii: + return "ascii" + case TypeBigInt: + return "bigint" + case TypeBlob: + return "blob" + case TypeBoolean: + return "boolean" + case TypeCounter: + return "counter" + case TypeDecimal: + return "decimal" + case TypeDouble: + return "double" + case TypeFloat: + return "float" + case TypeInt: + return "int" + case TypeText: + return "text" + case TypeTimestamp: + return "timestamp" + case TypeUUID: + return "uuid" + case TypeVarchar: + return "varchar" + case TypeTimeUUID: + return "timeuuid" + case TypeInet: + return "inet" + case TypeDate: + return "date" + case TypeDuration: + return "duration" + case TypeTime: + return "time" + case TypeSmallInt: + return "smallint" + case TypeTinyInt: + return "tinyint" + case TypeList: + return "list" + case TypeMap: + return "map" + case TypeSet: + return "set" + case TypeVarint: + return "varint" + case TypeTuple: + return "tuple" + default: + return fmt.Sprintf("unknown_type_%d", t) + } +} diff --git a/internal/protocol/cqltypes.go b/internal/protocol/cqltypes.go new file mode 100644 index 000000000..d966df5f5 --- /dev/null +++ b/internal/protocol/cqltypes.go @@ -0,0 +1,178 @@ +package protocol + +import ( + "errors" + "fmt" + "strings" + "time" +) + +type Duration struct { + Months int32 + Days int32 + Nanoseconds int64 +} + +const ( + VariantNCSCompat = 0 + VariantIETF = 2 + VariantMicrosoft = 6 + VariantFuture = 7 +) + +// ParseUUID parses a 32 digit hexadecimal number (that might contain hypens) +// representing an UUID. +func ParseUUID(input string) (UUID, error) { + var u UUID + j := 0 + for _, r := range input { + switch { + case r == '-' && j&1 == 0: + continue + case r >= '0' && r <= '9' && j < 32: + u[j/2] |= byte(r-'0') << uint(4-j&1*4) + case r >= 'a' && r <= 'f' && j < 32: + u[j/2] |= byte(r-'a'+10) << uint(4-j&1*4) + case r >= 'A' && r <= 'F' && j < 32: + u[j/2] |= byte(r-'A'+10) << uint(4-j&1*4) + default: + return UUID{}, fmt.Errorf("invalid UUID %q", input) + } + j += 1 + } + if j != 32 { + return UUID{}, fmt.Errorf("invalid UUID %q", input) + } + return u, nil +} + +// UUIDFromBytes converts a raw byte slice to an UUID. +func UUIDFromBytes(input []byte) (UUID, error) { + var u UUID + if len(input) != 16 { + return u, errors.New("UUIDs must be exactly 16 bytes long") + } + + copy(u[:], input) + return u, nil +} + +type UUID [16]byte + +var TimeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix() + +func (u UUID) String() string { + var offsets = [...]int{0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34} + const hexString = "0123456789abcdef" + r := make([]byte, 36) + for i, b := range u { + r[offsets[i]] = hexString[b>>4] + r[offsets[i]+1] = hexString[b&0xF] + } + r[8] = '-' + r[13] = '-' + r[18] = '-' + r[23] = '-' + return string(r) + +} + +// Bytes returns the raw byte slice for this UUID. A UUID is always 128 bits +// (16 bytes) long. +func (u UUID) Bytes() []byte { + return u[:] +} + +// Variant returns the variant of this UUID. This package will only generate +// UUIDs in the IETF variant. +func (u UUID) Variant() int { + x := u[8] + if x&0x80 == 0 { + return VariantNCSCompat + } + if x&0x40 == 0 { + return VariantIETF + } + if x&0x20 == 0 { + return VariantMicrosoft + } + return VariantFuture +} + +// Version extracts the version of this UUID variant. The RFC 4122 describes +// five kinds of UUIDs. +func (u UUID) Version() int { + return int(u[6] & 0xF0 >> 4) +} + +// Node extracts the MAC address of the node who generated this UUID. It will +// return nil if the UUID is not a time based UUID (version 1). +func (u UUID) Node() []byte { + if u.Version() != 1 { + return nil + } + return u[10:] +} + +// Clock extracts the clock sequence of this UUID. It will return zero if the +// UUID is not a time based UUID (version 1). +func (u UUID) Clock() uint32 { + if u.Version() != 1 { + return 0 + } + + // Clock sequence is the lower 14bits of u[8:10] + return uint32(u[8]&0x3F)<<8 | uint32(u[9]) +} + +// Timestamp extracts the timestamp information from a time based UUID +// (version 1). +func (u UUID) Timestamp() int64 { + if u.Version() != 1 { + return 0 + } + return int64(uint64(u[0])<<24|uint64(u[1])<<16| + uint64(u[2])<<8|uint64(u[3])) + + int64(uint64(u[4])<<40|uint64(u[5])<<32) + + int64(uint64(u[6]&0x0F)<<56|uint64(u[7])<<48) +} + +// Time is like Timestamp, except that it returns a time.Time. +func (u UUID) Time() time.Time { + if u.Version() != 1 { + return time.Time{} + } + t := u.Timestamp() + sec := t / 1e7 + nsec := (t % 1e7) * 100 + return time.Unix(sec+TimeBase, nsec).UTC() +} + +// Marshaling for JSON +func (u UUID) MarshalJSON() ([]byte, error) { + return []byte(`"` + u.String() + `"`), nil +} + +// Unmarshaling for JSON +func (u *UUID) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), `"`) + if len(str) > 36 { + return fmt.Errorf("invalid JSON UUID %s", str) + } + + parsed, err := ParseUUID(str) + if err == nil { + copy(u[:], parsed[:]) + } + + return err +} + +func (u UUID) MarshalText() ([]byte, error) { + return []byte(u.String()), nil +} + +func (u *UUID) UnmarshalText(text []byte) (err error) { + *u, err = ParseUUID(string(text)) + return +} diff --git a/internal/protocol/framer.go b/internal/protocol/framer.go new file mode 100644 index 000000000..358d2f6db --- /dev/null +++ b/internal/protocol/framer.go @@ -0,0 +1,2019 @@ +package protocol + +import ( + "errors" + "fmt" + "github.com/gocql/gocql/internal/compressor" + "github.com/gocql/gocql/internal/internal_errors" + "io" + "io/ioutil" + "net" + "runtime" + "strings" + "time" +) + +type UnsetColumn struct{} + +var ( + ErrFrameTooBig = errors.New("frame length is bigger than the maximum allowed") +) + +// UnsetValue represents a value used in a query binding that will be ignored by Cassandra. +// +// By setting a field to the unset value Cassandra will ignore the write completely. +// The main advantage is the ability to keep the same prepared statement even when you don't +// want to update some fields, where before you needed to make another prepared statement. +// +// UnsetValue is only available when using the version 4 of the protocol. +//var UnsetValue = UnsetColumn{} + +type NamedValue struct { + Name string + Value interface{} +} + +const ( + ProtoDirectionMask = 0x80 + ProtoVersionMask = 0x7F + ProtoVersion1 = 0x01 + ProtoVersion2 = 0x02 + ProtoVersion3 = 0x03 + ProtoVersion4 = 0x04 + ProtoVersion5 = 0x05 + + MaxFrameSize = 256 * 1024 * 1024 +) + +type ProtoVersion byte + +func (p ProtoVersion) request() bool { + return p&ProtoDirectionMask == 0x00 +} + +func (p ProtoVersion) response() bool { + return p&ProtoDirectionMask == 0x80 +} + +func (p ProtoVersion) Version() byte { + return byte(p) & ProtoVersionMask +} + +func (p ProtoVersion) String() string { + dir := "REQ" + if p.response() { + dir = "RESP" + } + + return fmt.Sprintf("[version=%d direction=%s]", p.Version(), dir) +} + +type FrameOp byte + +const ( + // header ops + opError FrameOp = 0x00 + opStartup FrameOp = 0x01 + opReady FrameOp = 0x02 + opAuthenticate FrameOp = 0x03 + opOptions FrameOp = 0x05 + opSupported FrameOp = 0x06 + opQuery FrameOp = 0x07 + opResult FrameOp = 0x08 + opPrepare FrameOp = 0x09 + opExecute FrameOp = 0x0A + opRegister FrameOp = 0x0B + opEvent FrameOp = 0x0C + opBatch FrameOp = 0x0D + opAuthChallenge FrameOp = 0x0E + opAuthResponse FrameOp = 0x0F + opAuthSuccess FrameOp = 0x10 +) + +func (f FrameOp) String() string { + switch f { + case opError: + return "ERROR" + case opStartup: + return "STARTUP" + case opReady: + return "READY" + case opAuthenticate: + return "AUTHENTICATE" + case opOptions: + return "OPTIONS" + case opSupported: + return "SUPPORTED" + case opQuery: + return "QUERY" + case opResult: + return "RESULT" + case opPrepare: + return "PREPARE" + case opExecute: + return "EXECUTE" + case opRegister: + return "REGISTER" + case opEvent: + return "EVENT" + case opBatch: + return "BATCH" + case opAuthChallenge: + return "AUTH_CHALLENGE" + case opAuthResponse: + return "AUTH_RESPONSE" + case opAuthSuccess: + return "AUTH_SUCCESS" + default: + return fmt.Sprintf("UNKNOWN_OP_%d", f) + } +} + +const ( + // result kind + resultKindVoid = 1 + resultKindRows = 2 + resultKindKeyspace = 3 + resultKindPrepared = 4 + resultKindSchemaChanged = 5 + + // rows flags + flagGlobalTableSpec int = 0x01 + flagHasMorePages int = 0x02 + flagNoMetaData int = 0x04 + + // query flags + flagValues byte = 0x01 + flagSkipMetaData byte = 0x02 + flagPageSize byte = 0x04 + flagWithPagingState byte = 0x08 + flagWithSerialConsistency byte = 0x10 + flagDefaultTimestamp byte = 0x20 + flagWithNameValues byte = 0x40 + flagWithKeyspace byte = 0x80 + + // prepare flags + flagWithPreparedKeyspace uint32 = 0x01 + + // header flags + flagCompress byte = 0x01 + flagTracing byte = 0x02 + flagCustomPayload byte = 0x04 + flagWarning byte = 0x08 + flagBetaProtocol byte = 0x10 +) + +type Consistency uint16 + +const ( + Any Consistency = 0x00 + One Consistency = 0x01 + Two Consistency = 0x02 + Three Consistency = 0x03 + Quorum Consistency = 0x04 + All Consistency = 0x05 + LocalQuorum Consistency = 0x06 + EachQuorum Consistency = 0x07 + LocalOne Consistency = 0x0A +) + +func (c Consistency) String() string { + switch c { + case Any: + return "ANY" + case One: + return "ONE" + case Two: + return "TWO" + case Three: + return "THREE" + case Quorum: + return "QUORUM" + case All: + return "ALL" + case LocalQuorum: + return "LOCAL_QUORUM" + case EachQuorum: + return "EACH_QUORUM" + case LocalOne: + return "LOCAL_ONE" + default: + return fmt.Sprintf("UNKNOWN_CONS_0x%x", uint16(c)) + } +} + +func (c Consistency) MarshalText() (text []byte, err error) { + return []byte(c.String()), nil +} + +func (c *Consistency) UnmarshalText(text []byte) error { + switch string(text) { + case "ANY": + *c = Any + case "ONE": + *c = One + case "TWO": + *c = Two + case "THREE": + *c = Three + case "QUORUM": + *c = Quorum + case "ALL": + *c = All + case "LOCAL_QUORUM": + *c = LocalQuorum + case "EACH_QUORUM": + *c = EachQuorum + case "LOCAL_ONE": + *c = LocalOne + default: + return fmt.Errorf("invalid consistency %q", string(text)) + } + + return nil +} + +func ParseConsistency(s string) Consistency { + var c Consistency + if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { + panic(err) + } + return c +} + +// ParseConsistencyWrapper wraps gocql.ParseConsistency to provide an err +// return instead of a panic +func ParseConsistencyWrapper(s string) (consistency Consistency, err error) { + err = consistency.UnmarshalText([]byte(strings.ToUpper(s))) + return +} + +// MustParseConsistency is the same as ParseConsistency except it returns +// an error (never). It is kept here since breaking changes are not good. +// DEPRECATED: use ParseConsistency if you want a panic on parse error. +func MustParseConsistency(s string) (Consistency, error) { + c, err := ParseConsistencyWrapper(s) + if err != nil { + panic(err) + } + return c, nil +} + +type SerialConsistency uint16 + +const ( + Serial SerialConsistency = 0x08 + LocalSerial SerialConsistency = 0x09 +) + +func (s SerialConsistency) String() string { + switch s { + case Serial: + return "SERIAL" + case LocalSerial: + return "LOCAL_SERIAL" + default: + return fmt.Sprintf("UNKNOWN_SERIAL_CONS_0x%x", uint16(s)) + } +} + +func (s SerialConsistency) MarshalText() (text []byte, err error) { + return []byte(s.String()), nil +} + +func (s *SerialConsistency) UnmarshalText(text []byte) error { + switch string(text) { + case "SERIAL": + *s = Serial + case "LOCAL_SERIAL": + *s = LocalSerial + default: + return fmt.Errorf("invalid consistency %q", string(text)) + } + + return nil +} + +const ( + ApacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." +) + +const MaxFrameHeaderSize = 9 + +func ReadInt(p []byte) int32 { + return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) +} + +type FrameHeader struct { + Version ProtoVersion + Flags byte + Stream int + Op FrameOp + Length int + Warnings []string +} + +func (f FrameHeader) String() string { + return fmt.Sprintf("[header version=%s flags=0x%x stream=%d op=%s length=%d]", f.Version, f.Flags, f.Stream, f.Op, f.Length) +} + +func (f FrameHeader) Header() FrameHeader { + return f +} + +const defaultBufSize = 128 + +// a Framer is responsible for reading, writing and parsing frames on a single stream +type Framer struct { + proto byte + // flags are for outgoing flags, enabling compression and tracing etc + flags byte + compres compressor.Compressor + headSize int + // if this frame was read then the header will be here + Header *FrameHeader + + // if tracing flag is set this is not nil + TraceID []byte + + // holds a ref to the whole byte slice for buf so that it can be reset to + // 0 after a read. + readBuffer []byte + + Buf []byte + + CustomPayload map[string][]byte +} + +func NewFramer(compressor compressor.Compressor, version byte) *Framer { + buf := make([]byte, defaultBufSize) + f := &Framer{ + Buf: buf[:0], + readBuffer: buf, + } + var flags byte + if compressor != nil { + flags |= flagCompress + } + if version == ProtoVersion5 { + flags |= flagBetaProtocol + } + + version &= ProtoVersionMask + + headSize := 8 + if version > ProtoVersion2 { + headSize = 9 + } + + f.compres = compressor + f.proto = version + f.flags = flags + f.headSize = headSize + + f.Header = nil + f.TraceID = nil + + return f +} + +type Frame interface { + Header() FrameHeader +} + +func ReadHeader(r io.Reader, p []byte) (head FrameHeader, err error) { + _, err = io.ReadFull(r, p[:1]) + if err != nil { + return FrameHeader{}, err + } + + version := p[0] & ProtoVersionMask + + if version < ProtoVersion1 || version > ProtoVersion5 { + return FrameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) + } + + headSize := 9 + if version < ProtoVersion3 { + headSize = 8 + } + + _, err = io.ReadFull(r, p[1:headSize]) + if err != nil { + return FrameHeader{}, err + } + + p = p[:headSize] + + head.Version = ProtoVersion(p[0]) + head.Flags = p[1] + + if version > ProtoVersion2 { + if len(p) != 9 { + return FrameHeader{}, fmt.Errorf("not enough bytes to read header require 9 got: %d", len(p)) + } + + head.Stream = int(int16(p[2])<<8 | int16(p[3])) + head.Op = FrameOp(p[4]) + head.Length = int(ReadInt(p[5:])) + } else { + if len(p) != 8 { + return FrameHeader{}, fmt.Errorf("not enough bytes to read header require 8 got: %d", len(p)) + } + + head.Stream = int(int8(p[2])) + head.Op = FrameOp(p[3]) + head.Length = int(ReadInt(p[4:])) + } + + return head, nil +} + +// explicitly enables tracing for the Framers outgoing requests +func (f *Framer) Trace() { + f.flags |= flagTracing +} + +// explicitly enables the custom payload flag +func (f *Framer) payload() { + f.flags |= flagCustomPayload +} + +// reads a frame form the wire into the Framers buffer +func (f *Framer) ReadFrame(r io.Reader, head *FrameHeader) error { + if head.Length < 0 { + return fmt.Errorf("frame body length can not be less than 0: %d", head.Length) + } else if head.Length > MaxFrameSize { + // need to free up the connection to be used again + _, err := io.CopyN(ioutil.Discard, r, int64(head.Length)) + if err != nil { + return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err) + } + return ErrFrameTooBig + } + + if cap(f.readBuffer) >= head.Length { + f.Buf = f.readBuffer[:head.Length] + } else { + f.readBuffer = make([]byte, head.Length) + f.Buf = f.readBuffer + } + + // assume the underlying reader takes care of timeouts and retries + n, err := io.ReadFull(r, f.Buf) + if err != nil { + return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.Length, err) + } + + if head.Flags&flagCompress == flagCompress { + if f.compres == nil { + return NewErrProtocol("no compressor available with compressed frame body") + } + + f.Buf, err = f.compres.Decode(f.Buf) + if err != nil { + return err + } + } + + f.Header = head + return nil +} + +func (f *Framer) ParseFrame() (frame Frame, err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + err = r.(error) + } + }() + + if f.Header.Version.request() { + return nil, NewErrProtocol("got a request frame from server: %v", f.Header.Version) + } + + if f.Header.Flags&flagTracing == flagTracing { + f.readTrace() + } + + if f.Header.Flags&flagWarning == flagWarning { + f.Header.Warnings = f.readStringList() + } + + if f.Header.Flags&flagCustomPayload == flagCustomPayload { + f.CustomPayload = f.readBytesMap() + } + + // assumes that the frame body has been read into rbuf + switch f.Header.Op { + case opError: + frame = f.parseErrorFrame() + case opReady: + frame = f.parseReadyFrame() + case opResult: + frame, err = f.parseResultFrame() + case opSupported: + frame = f.parseSupportedFrame() + case opAuthenticate: + frame = f.parseAuthenticateFrame() + case opAuthChallenge: + frame = f.parseAuthChallengeFrame() + case opAuthSuccess: + frame = f.parseAuthSuccessFrame() + case opEvent: + frame = f.parseEventFrame() + default: + return nil, NewErrProtocol("unknown op in frame header: %s", f.Header.Op) + } + + return +} + +func (f *Framer) parseErrorFrame() Frame { + code := f.readInt() + msg := f.readString() + + errD := internal_errors.ErrorFrame{ + FrameHeader: *f.Header, + Cod: code, + Messag: msg, + } + + switch code { + case internal_errors.ErrCodeUnavailable: + cl := f.readConsistency() + required := f.readInt() + alive := f.readInt() + return &internal_errors.RequestErrUnavailable{ + ErrorFrame: errD, + Consistency: cl, + Required: required, + Alive: alive, + } + case internal_errors.ErrCodeWriteTimeout: + cl := f.readConsistency() + received := f.readInt() + blockfor := f.readInt() + writeType := f.readString() + return &internal_errors.RequestErrWriteTimeout{ + ErrorFrame: errD, + Consistency: cl, + Received: received, + BlockFor: blockfor, + WriteType: writeType, + } + case internal_errors.ErrCodeReadTimeout: + cl := f.readConsistency() + received := f.readInt() + blockfor := f.readInt() + dataPresent := f.readByte() + return &internal_errors.RequestErrReadTimeout{ + ErrorFrame: errD, + Consistency: cl, + Received: received, + BlockFor: blockfor, + DataPresent: dataPresent, + } + case internal_errors.ErrCodeAlreadyExists: + ks := f.readString() + table := f.readString() + return &internal_errors.RequestErrAlreadyExists{ + ErrorFrame: errD, + Keyspace: ks, + Table: table, + } + case internal_errors.ErrCodeUnprepared: + stmtId := f.readShortBytes() + return &internal_errors.RequestErrUnprepared{ + ErrorFrame: errD, + StatementId: CopyBytes(stmtId), // defensively copy + } + case internal_errors.ErrCodeReadFailure: + res := &internal_errors.RequestErrReadFailure{ + ErrorFrame: errD, + } + res.Consistency = f.readConsistency() + res.Received = f.readInt() + res.BlockFor = f.readInt() + if f.proto > ProtoVersion4 { + res.ErrorMap = f.readErrorMap() + res.NumFailures = len(res.ErrorMap) + } else { + res.NumFailures = f.readInt() + } + res.DataPresent = f.readByte() != 0 + + return res + case internal_errors.ErrCodeWriteFailure: + res := &internal_errors.RequestErrWriteFailure{ + ErrorFrame: errD, + } + res.Consistency = f.readConsistency() + res.Received = f.readInt() + res.BlockFor = f.readInt() + if f.proto > ProtoVersion4 { + res.ErrorMap = f.readErrorMap() + res.NumFailures = len(res.ErrorMap) + } else { + res.NumFailures = f.readInt() + } + res.WriteType = f.readString() + return res + case internal_errors.ErrCodeFunctionFailure: + res := &internal_errors.RequestErrFunctionFailure{ + ErrorFrame: errD, + } + res.Keyspace = f.readString() + res.Function = f.readString() + res.ArgTypes = f.readStringList() + return res + + case internal_errors.ErrCodeCDCWriteFailure: + res := &internal_errors.RequestErrCDCWriteFailure{ + ErrorFrame: errD, + } + return res + case internal_errors.ErrCodeCASWriteUnknown: + res := &internal_errors.RequestErrCASWriteUnknown{ + ErrorFrame: errD, + } + res.Consistency = f.readConsistency() + res.Received = f.readInt() + res.BlockFor = f.readInt() + return res + case internal_errors.ErrCodeInvalid, internal_errors.ErrCodeBootstrapping, internal_errors.ErrCodeConfig, internal_errors.ErrCodeCredentials, internal_errors.ErrCodeOverloaded, + internal_errors.ErrCodeProtocol, internal_errors.ErrCodeServer, internal_errors.ErrCodeSyntax, internal_errors.ErrCodeTruncate, internal_errors.ErrCodeUnauthorized: + // TODO(zariel): we should have some distinct types for these internal_errors + return errD + default: + panic(fmt.Errorf("unknown error code: 0x%x", errD.Cod)) + } +} + +func (f *Framer) readErrorMap() (errMap internal_errors.ErrorMap) { + errMap = make(internal_errors.ErrorMap) + numErrs := f.readInt() + for i := 0; i < numErrs; i++ { + ip := f.readInetAdressOnly().String() + errMap[ip] = f.readShort() + } + return +} + +func (f *Framer) writeHeader(flags byte, op FrameOp, stream int) { + f.Buf = f.Buf[:0] + f.Buf = append(f.Buf, + f.proto, + flags, + ) + + if f.proto > ProtoVersion2 { + f.Buf = append(f.Buf, + byte(stream>>8), + byte(stream), + ) + } else { + f.Buf = append(f.Buf, + byte(stream), + ) + } + + // pad out length + f.Buf = append(f.Buf, + byte(op), + 0, + 0, + 0, + 0, + ) +} + +func (f *Framer) setLength(length int) { + p := 4 + if f.proto > ProtoVersion2 { + p = 5 + } + + f.Buf[p+0] = byte(length >> 24) + f.Buf[p+1] = byte(length >> 16) + f.Buf[p+2] = byte(length >> 8) + f.Buf[p+3] = byte(length) +} + +func (f *Framer) finish() error { + if len(f.Buf) > MaxFrameSize { + // huge app frame, lets remove it so it doesn't bloat the heap + f.Buf = make([]byte, defaultBufSize) + return ErrFrameTooBig + } + + if f.Buf[1]&flagCompress == flagCompress { + if f.compres == nil { + panic("compress flag set with no compressor") + } + + // TODO: only compress frames which are big enough + compressed, err := f.compres.Encode(f.Buf[f.headSize:]) + if err != nil { + return err + } + + f.Buf = append(f.Buf[:f.headSize], compressed...) + } + length := len(f.Buf) - f.headSize + f.setLength(length) + + return nil +} + +func (f *Framer) writeTo(w io.Writer) error { + _, err := w.Write(f.Buf) + return err +} + +func (f *Framer) readTrace() { + f.TraceID = f.readUUID().Bytes() +} + +type ReadyFrame struct { + FrameHeader +} + +func (f *Framer) parseReadyFrame() Frame { + return &ReadyFrame{ + FrameHeader: *f.Header, + } +} + +type SupportedFrame struct { + FrameHeader + + Supported map[string][]string +} + +// TODO: if we move the body buffer onto the FrameHeader then we only need a single +// Framer, and can move the methods onto the header. +func (f *Framer) parseSupportedFrame() Frame { + return &SupportedFrame{ + FrameHeader: *f.Header, + + Supported: f.readStringMultiMap(), + } +} + +type WriteStartupFrame struct { + Opts map[string]string +} + +func (w WriteStartupFrame) String() string { + return fmt.Sprintf("[startup opts=%+v]", w.Opts) +} + +func (w *WriteStartupFrame) BuildFrame(f *Framer, streamID int) error { + f.writeHeader(f.flags&^flagCompress, opStartup, streamID) + f.writeStringMap(w.Opts) + + return f.finish() +} + +type WritePrepareFrame struct { + Statement string + Keyspace string + customPayload map[string][]byte +} + +func (w *WritePrepareFrame) BuildFrame(f *Framer, streamID int) error { + if len(w.customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opPrepare, streamID) + f.writeCustomPayload(&w.customPayload) + f.writeLongString(w.Statement) + + var flags uint32 = 0 + if w.Keyspace != "" { + if f.proto > ProtoVersion4 { + flags |= flagWithPreparedKeyspace + } else { + panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) + } + } + if f.proto > ProtoVersion4 { + f.writeUint(flags) + } + if w.Keyspace != "" { + f.writeString(w.Keyspace) + } + + return f.finish() +} + +func (f *Framer) readTypeInfo() TypeInfo { + // TODO: factor this out so the same code paths can be used to parse custom + // types and other types, as much of the logic will be duplicated. + id := f.readShort() + + simple := NativeType{ + proto: f.proto, + Typ: Type(id), + } + + if simple.Typ == TypeCustom { + simple.Cust = f.readString() + if cassType := GetApacheCassandraType(simple.Cust); cassType != TypeCustom { + simple.Typ = cassType + } + } + + switch simple.Typ { + case TypeTuple: + n := f.readShort() + tuple := TupleTypeInfo{ + NativeType: simple, + Elems: make([]TypeInfo, n), + } + + for i := 0; i < int(n); i++ { + tuple.Elems[i] = f.readTypeInfo() + } + + return tuple + + case TypeUDT: + udt := UDTTypeInfo{ + NativeType: simple, + } + udt.KeySpace = f.readString() + udt.Name = f.readString() + + n := f.readShort() + udt.Elements = make([]UDTField, n) + for i := 0; i < int(n); i++ { + field := &udt.Elements[i] + field.Name = f.readString() + field.Type = f.readTypeInfo() + } + + return udt + case TypeMap, TypeList, TypeSet: + collection := CollectionType{ + NativeType: simple, + } + + if simple.Typ == TypeMap { + collection.Key = f.readTypeInfo() + } + + collection.Elem = f.readTypeInfo() + + return collection + } + + return simple +} + +type PreparedMetadata struct { + ResultMetadata + + // proto v4+ + PkeyColumns []int + + Keyspace string + + Table string +} + +func (r PreparedMetadata) String() string { + return fmt.Sprintf("[prepared flags=0x%x pkey=%v paging_state=% X columns=%v col_count=%d actual_col_count=%d]", r.Flags, r.PkeyColumns, r.PagingState, r.Columns, r.ColCount, r.ActualColCount) +} + +func (f *Framer) parsePreparedMetadata() PreparedMetadata { + // TODO: deduplicate this from parseMetadata + meta := PreparedMetadata{} + + meta.Flags = f.readInt() + meta.ColCount = f.readInt() + if meta.ColCount < 0 { + panic(fmt.Errorf("received negative column count: %d", meta.ColCount)) + } + meta.ActualColCount = meta.ColCount + + if f.proto >= ProtoVersion4 { + pkeyCount := f.readInt() + pkeys := make([]int, pkeyCount) + for i := 0; i < pkeyCount; i++ { + pkeys[i] = int(f.readShort()) + } + meta.PkeyColumns = pkeys + } + + if meta.Flags&flagHasMorePages == flagHasMorePages { + meta.PagingState = CopyBytes(f.readBytes()) + } + + if meta.Flags&flagNoMetaData == flagNoMetaData { + return meta + } + + globalSpec := meta.Flags&flagGlobalTableSpec == flagGlobalTableSpec + if globalSpec { + meta.Keyspace = f.readString() + meta.Table = f.readString() + } + + var cols []ColumnInfo + if meta.ColCount < 1000 { + // preallocate columninfo to avoid excess copying + cols = make([]ColumnInfo, meta.ColCount) + for i := 0; i < meta.ColCount; i++ { + f.readCol(&cols[i], &meta.ResultMetadata, globalSpec, meta.Keyspace, meta.Table) + } + } else { + // use append, huge number of columns usually indicates a corrupt frame or + // just a huge row. + for i := 0; i < meta.ColCount; i++ { + var col ColumnInfo + f.readCol(&col, &meta.ResultMetadata, globalSpec, meta.Keyspace, meta.Table) + cols = append(cols, col) + } + } + + meta.Columns = cols + + return meta +} + +type ResultMetadata struct { + Flags int + + // only if flagPageState + PagingState []byte + + Columns []ColumnInfo + ColCount int + + // this is a count of the total number of columns which can be scanned, + // it is at minimum len(columns) but may be larger, for instance when a column + // is a UDT or tuple. + ActualColCount int +} + +func (r *ResultMetadata) MorePages() bool { + return r.Flags&flagHasMorePages == flagHasMorePages +} + +func (r ResultMetadata) String() string { + return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.Flags, r.PagingState, r.Columns) +} + +func (f *Framer) readCol(col *ColumnInfo, meta *ResultMetadata, globalSpec bool, keyspace, table string) { + if !globalSpec { + col.Keyspace = f.readString() + col.Table = f.readString() + } else { + col.Keyspace = keyspace + col.Table = table + } + + col.Name = f.readString() + col.TypeInfo = f.readTypeInfo() + switch v := col.TypeInfo.(type) { + // maybe also UDT + case TupleTypeInfo: + // -1 because we already included the tuple column + meta.ActualColCount += len(v.Elems) - 1 + } +} + +func (f *Framer) parseResultMetadata() ResultMetadata { + var meta ResultMetadata + + meta.Flags = f.readInt() + meta.ColCount = f.readInt() + if meta.ColCount < 0 { + panic(fmt.Errorf("received negative column count: %d", meta.ColCount)) + } + meta.ActualColCount = meta.ColCount + + if meta.Flags&flagHasMorePages == flagHasMorePages { + meta.PagingState = CopyBytes(f.readBytes()) + } + + if meta.Flags&flagNoMetaData == flagNoMetaData { + return meta + } + + var keyspace, table string + globalSpec := meta.Flags&flagGlobalTableSpec == flagGlobalTableSpec + if globalSpec { + keyspace = f.readString() + table = f.readString() + } + + var cols []ColumnInfo + if meta.ColCount < 1000 { + // preallocate columninfo to avoid excess copying + cols = make([]ColumnInfo, meta.ColCount) + for i := 0; i < meta.ColCount; i++ { + f.readCol(&cols[i], &meta, globalSpec, keyspace, table) + } + + } else { + // use append, huge number of columns usually indicates a corrupt frame or + // just a huge row. + for i := 0; i < meta.ColCount; i++ { + var col ColumnInfo + f.readCol(&col, &meta, globalSpec, keyspace, table) + cols = append(cols, col) + } + } + + meta.Columns = cols + + return meta +} + +type ResultVoidFrame struct { + FrameHeader +} + +func (f *ResultVoidFrame) String() string { + return "[result_void]" +} + +func (f *Framer) parseResultFrame() (Frame, error) { + kind := f.readInt() + + switch kind { + case resultKindVoid: + return &ResultVoidFrame{FrameHeader: *f.Header}, nil + case resultKindRows: + return f.parseResultRows(), nil + case resultKindKeyspace: + return f.parseResultSetKeyspace(), nil + case resultKindPrepared: + return f.parseResultPrepared(), nil + case resultKindSchemaChanged: + return f.parseResultSchemaChange(), nil + } + + return nil, NewErrProtocol("unknown result kind: %x", kind) +} + +type ResultRowsFrame struct { + FrameHeader + + Meta ResultMetadata + // dont parse the rows here as we only need to do it once + NumRows int +} + +func (f *ResultRowsFrame) String() string { + return fmt.Sprintf("[result_rows meta=%v]", f.Meta) +} + +func (f *Framer) parseResultRows() Frame { + result := &ResultRowsFrame{} + result.Meta = f.parseResultMetadata() + + result.NumRows = f.readInt() + if result.NumRows < 0 { + panic(fmt.Errorf("invalid row_count in result frame: %d", result.NumRows)) + } + + return result +} + +type ResultKeyspaceFrame struct { + FrameHeader + keyspace string +} + +func (r *ResultKeyspaceFrame) String() string { + return fmt.Sprintf("[result_keyspace keyspace=%s]", r.keyspace) +} + +func (f *Framer) parseResultSetKeyspace() Frame { + return &ResultKeyspaceFrame{ + FrameHeader: *f.Header, + keyspace: f.readString(), + } +} + +type ResultPreparedFrame struct { + FrameHeader + + PreparedID []byte + ReqMeta PreparedMetadata + RespMeta ResultMetadata +} + +func (f *Framer) parseResultPrepared() Frame { + frame := &ResultPreparedFrame{ + FrameHeader: *f.Header, + PreparedID: f.readShortBytes(), + ReqMeta: f.parsePreparedMetadata(), + } + + if f.proto < ProtoVersion2 { + return frame + } + + frame.RespMeta = f.parseResultMetadata() + + return frame +} + +type SchemaChangeKeyspace struct { + FrameHeader + + Change string + Keyspace string +} + +func (f SchemaChangeKeyspace) String() string { + return fmt.Sprintf("[event schema_change_keyspace change=%q keyspace=%q]", f.Change, f.Keyspace) +} + +type SchemaChangeTable struct { + FrameHeader + + change string + Keyspace string + object string +} + +func (f SchemaChangeTable) String() string { + return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.Keyspace, f.object) +} + +type SchemaChangeType struct { + FrameHeader + + change string + Keyspace string + object string +} + +type SchemaChangeFunction struct { + FrameHeader + + change string + Keyspace string + name string + args []string +} + +type SchemaChangeAggregate struct { + FrameHeader + + change string + Keyspace string + name string + args []string +} + +func (f *Framer) parseResultSchemaChange() Frame { + if f.proto <= ProtoVersion2 { + change := f.readString() + keyspace := f.readString() + table := f.readString() + + if table != "" { + return &SchemaChangeTable{ + FrameHeader: *f.Header, + change: change, + Keyspace: keyspace, + object: table, + } + } else { + return &SchemaChangeKeyspace{ + FrameHeader: *f.Header, + Change: change, + Keyspace: keyspace, + } + } + } else { + change := f.readString() + target := f.readString() + + // TODO: could just use a separate type for each target + switch target { + case "KEYSPACE": + frame := &SchemaChangeKeyspace{ + FrameHeader: *f.Header, + Change: change, + } + + frame.Keyspace = f.readString() + + return frame + case "TABLE": + frame := &SchemaChangeTable{ + FrameHeader: *f.Header, + change: change, + } + + frame.Keyspace = f.readString() + frame.object = f.readString() + + return frame + case "TYPE": + frame := &SchemaChangeType{ + FrameHeader: *f.Header, + change: change, + } + + frame.Keyspace = f.readString() + frame.object = f.readString() + + return frame + case "FUNCTION": + frame := &SchemaChangeFunction{ + FrameHeader: *f.Header, + change: change, + } + + frame.Keyspace = f.readString() + frame.name = f.readString() + frame.args = f.readStringList() + + return frame + case "AGGREGATE": + frame := &SchemaChangeAggregate{ + FrameHeader: *f.Header, + change: change, + } + + frame.Keyspace = f.readString() + frame.name = f.readString() + frame.args = f.readStringList() + + return frame + default: + panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change)) + } + } + +} + +type AuthenticateFrame struct { + FrameHeader + + Class string +} + +func (a *AuthenticateFrame) String() string { + return fmt.Sprintf("[authenticate class=%q]", a.Class) +} + +func (f *Framer) parseAuthenticateFrame() Frame { + return &AuthenticateFrame{ + FrameHeader: *f.Header, + Class: f.readString(), + } +} + +type AuthSuccessFrame struct { + FrameHeader + + Data []byte +} + +func (a *AuthSuccessFrame) String() string { + return fmt.Sprintf("[auth_success data=%q]", a.Data) +} + +func (f *Framer) parseAuthSuccessFrame() Frame { + return &AuthSuccessFrame{ + FrameHeader: *f.Header, + Data: f.readBytes(), + } +} + +type AuthChallengeFrame struct { + FrameHeader + + Data []byte +} + +func (a *AuthChallengeFrame) String() string { + return fmt.Sprintf("[auth_challenge data=%q]", a.Data) +} + +func (f *Framer) parseAuthChallengeFrame() Frame { + return &AuthChallengeFrame{ + FrameHeader: *f.Header, + Data: f.readBytes(), + } +} + +type StatusChangeEventFrame struct { + FrameHeader + + Change string + Host net.IP + Port int +} + +func (t StatusChangeEventFrame) String() string { + return fmt.Sprintf("[status_change change=%s host=%v port=%v]", t.Change, t.Host, t.Port) +} + +// essentially the same as statusChange +type TopologyChangeEventFrame struct { + FrameHeader + + change string + host net.IP + port int +} + +func (t TopologyChangeEventFrame) String() string { + return fmt.Sprintf("[topology_change change=%s host=%v port=%v]", t.change, t.host, t.port) +} + +func (f *Framer) parseEventFrame() Frame { + eventType := f.readString() + + switch eventType { + case "TOPOLOGY_CHANGE": + frame := &TopologyChangeEventFrame{FrameHeader: *f.Header} + frame.change = f.readString() + frame.host, frame.port = f.readInet() + + return frame + case "STATUS_CHANGE": + frame := &StatusChangeEventFrame{FrameHeader: *f.Header} + frame.Change = f.readString() + frame.Host, frame.Port = f.readInet() + + return frame + case "SCHEMA_CHANGE": + // this should work for all versions + return f.parseResultSchemaChange() + default: + panic(fmt.Errorf("gocql: unknown event type: %q", eventType)) + } + +} + +type WriteAuthResponseFrame struct { + Data []byte +} + +func (a *WriteAuthResponseFrame) String() string { + return fmt.Sprintf("[auth_response data=%q]", a.Data) +} + +func (a *WriteAuthResponseFrame) BuildFrame(framer *Framer, streamID int) error { + return framer.WriteAuthResponseFrame(streamID, a.Data) +} + +func (f *Framer) WriteAuthResponseFrame(streamID int, data []byte) error { + f.writeHeader(f.flags, opAuthResponse, streamID) + f.writeBytes(data) + return f.finish() +} + +type QueryValues struct { + Value []byte + + // optional name, will set With names for values flag + Name string + IsUnset bool +} + +type QueryParams struct { + Consistency + // v2+ + SkipMeta bool + Values []QueryValues + PageSize int + PagingState []byte + SerialConsistency SerialConsistency + // v3+ + DefaultTimestamp bool + DefaultTimestampValue int64 + // v5+ + Keyspace string +} + +func (q QueryParams) String() string { + return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]", + q.Consistency, q.SkipMeta, q.PageSize, q.PagingState, q.SerialConsistency, q.DefaultTimestamp, q.Values, q.Keyspace) +} + +func (f *Framer) writeQueryParams(opts *QueryParams) { + f.writeConsistency(opts.Consistency) + + if f.proto == ProtoVersion1 { + return + } + + var flags byte + if len(opts.Values) > 0 { + flags |= flagValues + } + if opts.SkipMeta { + flags |= flagSkipMetaData + } + if opts.PageSize > 0 { + flags |= flagPageSize + } + if len(opts.PagingState) > 0 { + flags |= flagWithPagingState + } + if opts.SerialConsistency > 0 { + flags |= flagWithSerialConsistency + } + + names := false + + // protoV3 specific things + if f.proto > ProtoVersion2 { + if opts.DefaultTimestamp { + flags |= flagDefaultTimestamp + } + + if len(opts.Values) > 0 && opts.Values[0].Name != "" { + flags |= flagWithNameValues + names = true + } + } + + if opts.Keyspace != "" { + if f.proto > ProtoVersion4 { + flags |= flagWithKeyspace + } else { + panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) + } + } + + if f.proto > ProtoVersion4 { + f.writeUint(uint32(flags)) + } else { + f.writeByte(flags) + } + + if n := len(opts.Values); n > 0 { + f.writeShort(uint16(n)) + + for i := 0; i < n; i++ { + if names { + f.writeString(opts.Values[i].Name) + } + if opts.Values[i].IsUnset { + f.writeUnset() + } else { + f.writeBytes(opts.Values[i].Value) + } + } + } + + if opts.PageSize > 0 { + f.writeInt(int32(opts.PageSize)) + } + + if len(opts.PagingState) > 0 { + f.writeBytes(opts.PagingState) + } + + if opts.SerialConsistency > 0 { + f.writeConsistency(Consistency(opts.SerialConsistency)) + } + + if f.proto > ProtoVersion2 && opts.DefaultTimestamp { + // timestamp in microseconds + var ts int64 + if opts.DefaultTimestampValue != 0 { + ts = opts.DefaultTimestampValue + } else { + ts = time.Now().UnixNano() / 1000 + } + f.writeLong(ts) + } + + if opts.Keyspace != "" { + f.writeString(opts.Keyspace) + } +} + +type WriteQueryFrame struct { + Statement string + Params QueryParams + + // v4+ + CustomPayload map[string][]byte +} + +func (w *WriteQueryFrame) String() string { + return fmt.Sprintf("[query statement=%q params=%v]", w.Statement, w.Params) +} + +func (w *WriteQueryFrame) BuildFrame(framer *Framer, streamID int) error { + return framer.WriteQueryFrame(streamID, w.Statement, &w.Params, w.CustomPayload) +} + +func (f *Framer) WriteQueryFrame(streamID int, statement string, params *QueryParams, customPayload map[string][]byte) error { + if len(customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opQuery, streamID) + f.writeCustomPayload(&customPayload) + f.writeLongString(statement) + f.writeQueryParams(params) + + return f.finish() +} + +type FrameBuilder interface { + BuildFrame(framer *Framer, streamID int) error +} + +type frameWriterFunc func(framer *Framer, streamID int) error + +func (f frameWriterFunc) BuildFrame(framer *Framer, streamID int) error { + return f(framer, streamID) +} + +type WriteExecuteFrame struct { + PreparedID []byte + Params QueryParams + + // v4+ + CustomPayload map[string][]byte +} + +func (e *WriteExecuteFrame) String() string { + return fmt.Sprintf("[execute id=% X params=%v]", e.PreparedID, &e.Params) +} + +func (e *WriteExecuteFrame) BuildFrame(fr *Framer, streamID int) error { + return fr.WriteExecuteFrame(streamID, e.PreparedID, &e.Params, &e.CustomPayload) +} + +func (f *Framer) WriteExecuteFrame(streamID int, preparedID []byte, params *QueryParams, customPayload *map[string][]byte) error { + if len(*customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opExecute, streamID) + f.writeCustomPayload(customPayload) + f.writeShortBytes(preparedID) + if f.proto > ProtoVersion1 { + f.writeQueryParams(params) + } else { + n := len(params.Values) + f.writeShort(uint16(n)) + for i := 0; i < n; i++ { + if params.Values[i].IsUnset { + f.writeUnset() + } else { + f.writeBytes(params.Values[i].Value) + } + } + f.writeConsistency(params.Consistency) + } + + return f.finish() +} + +// TODO: can we replace BatchStatemt with batchStatement? As they prety much +// duplicate each other +type BatchStatment struct { + PreparedID []byte + Statement string + Values []QueryValues +} + +type WriteBatchFrame struct { + Typ BatchType + Statements []BatchStatment + Consistency Consistency + + // v3+ + SerialConsistency SerialConsistency + DefaultTimestamp bool + DefaultTimestampValue int64 + + //v4+ + CustomPayload map[string][]byte +} + +func (w *WriteBatchFrame) BuildFrame(framer *Framer, streamID int) error { + return framer.WriteBatchFrame(streamID, w, w.CustomPayload) +} + +func (f *Framer) WriteBatchFrame(streamID int, w *WriteBatchFrame, customPayload map[string][]byte) error { + if len(customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opBatch, streamID) + f.writeCustomPayload(&customPayload) + f.writeByte(byte(w.Typ)) + + n := len(w.Statements) + f.writeShort(uint16(n)) + + var flags byte + + for i := 0; i < n; i++ { + b := &w.Statements[i] + if len(b.PreparedID) == 0 { + f.writeByte(0) + f.writeLongString(b.Statement) + } else { + f.writeByte(1) + f.writeShortBytes(b.PreparedID) + } + + f.writeShort(uint16(len(b.Values))) + for j := range b.Values { + col := b.Values[j] + if f.proto > ProtoVersion2 && col.Name != "" { + // TODO: move this check into the caller and set a flag on WriteBatchFrame + // to indicate using named values + if f.proto <= ProtoVersion5 { + return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") + } + flags |= flagWithNameValues + f.writeString(col.Name) + } + if col.IsUnset { + f.writeUnset() + } else { + f.writeBytes(col.Value) + } + } + } + + f.writeConsistency(w.Consistency) + + if f.proto > ProtoVersion2 { + if w.SerialConsistency > 0 { + flags |= flagWithSerialConsistency + } + if w.DefaultTimestamp { + flags |= flagDefaultTimestamp + } + + if f.proto > ProtoVersion4 { + f.writeUint(uint32(flags)) + } else { + f.writeByte(flags) + } + + if w.SerialConsistency > 0 { + f.writeConsistency(Consistency(w.SerialConsistency)) + } + + if w.DefaultTimestamp { + var ts int64 + if w.DefaultTimestampValue != 0 { + ts = w.DefaultTimestampValue + } else { + ts = time.Now().UnixNano() / 1000 + } + f.writeLong(ts) + } + } + + return f.finish() +} + +type WriteOptionsFrame struct{} + +func (w *WriteOptionsFrame) BuildFrame(framer *Framer, streamID int) error { + return framer.WriteOptionsFrame(streamID, w) +} + +func (f *Framer) WriteOptionsFrame(stream int, _ *WriteOptionsFrame) error { + f.writeHeader(f.flags&^flagCompress, opOptions, stream) + return f.finish() +} + +type WriteRegisterFrame struct { + Events []string +} + +func (w *WriteRegisterFrame) BuildFrame(framer *Framer, streamID int) error { + return framer.WriteRegisterFrame(streamID, w) +} + +func (f *Framer) WriteRegisterFrame(streamID int, w *WriteRegisterFrame) error { + f.writeHeader(f.flags, opRegister, streamID) + f.writeStringList(w.Events) + + return f.finish() +} + +func (f *Framer) readByte() byte { + if len(f.Buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read byte require 1 got: %d", len(f.Buf))) + } + + b := f.Buf[0] + f.Buf = f.Buf[1:] + return b +} + +func (f *Framer) readInt() (n int) { + if len(f.Buf) < 4 { + panic(fmt.Errorf("not enough bytes in buffer to read int require 4 got: %d", len(f.Buf))) + } + + n = int(int32(f.Buf[0])<<24 | int32(f.Buf[1])<<16 | int32(f.Buf[2])<<8 | int32(f.Buf[3])) + f.Buf = f.Buf[4:] + return +} + +func (f *Framer) readShort() (n uint16) { + if len(f.Buf) < 2 { + panic(fmt.Errorf("not enough bytes in buffer to read short require 2 got: %d", len(f.Buf))) + } + n = uint16(f.Buf[0])<<8 | uint16(f.Buf[1]) + f.Buf = f.Buf[2:] + return +} + +func (f *Framer) readString() (s string) { + size := f.readShort() + + if len(f.Buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to read string require %d got: %d", size, len(f.Buf))) + } + + s = string(f.Buf[:size]) + f.Buf = f.Buf[size:] + return +} + +func (f *Framer) readLongString() (s string) { + size := f.readInt() + + if len(f.Buf) < size { + panic(fmt.Errorf("not enough bytes in buffer to read long string require %d got: %d", size, len(f.Buf))) + } + + s = string(f.Buf[:size]) + f.Buf = f.Buf[size:] + return +} + +func (f *Framer) readUUID() *UUID { + if len(f.Buf) < 16 { + panic(fmt.Errorf("not enough bytes in buffer to read uuid require %d got: %d", 16, len(f.Buf))) + } + + // TODO: how to handle this error, if it is a uuid, then sureley, problems? + u, _ := UUIDFromBytes(f.Buf[:16]) + f.Buf = f.Buf[16:] + return &u +} + +func (f *Framer) readStringList() []string { + size := f.readShort() + + l := make([]string, size) + for i := 0; i < int(size); i++ { + l[i] = f.readString() + } + + return l +} + +func (f *Framer) ReadBytesInternal() ([]byte, error) { + size := f.readInt() + if size < 0 { + return nil, nil + } + + if len(f.Buf) < size { + return nil, fmt.Errorf("not enough bytes in buffer to read bytes require %d got: %d", size, len(f.Buf)) + } + + l := f.Buf[:size] + f.Buf = f.Buf[size:] + + return l, nil +} + +func (f *Framer) readBytes() []byte { + l, err := f.ReadBytesInternal() + if err != nil { + panic(err) + } + + return l +} + +func (f *Framer) readShortBytes() []byte { + size := f.readShort() + if len(f.Buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.Buf))) + } + + l := f.Buf[:size] + f.Buf = f.Buf[size:] + + return l +} + +func (f *Framer) readInetAdressOnly() net.IP { + if len(f.Buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.Buf))) + } + + size := f.Buf[0] + f.Buf = f.Buf[1:] + + if !(size == 4 || size == 16) { + panic(fmt.Errorf("invalid IP size: %d", size)) + } + + if len(f.Buf) < 1 { + panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.Buf))) + } + + ip := make([]byte, size) + copy(ip, f.Buf[:size]) + f.Buf = f.Buf[size:] + return net.IP(ip) +} + +func (f *Framer) readInet() (net.IP, int) { + return f.readInetAdressOnly(), f.readInt() +} + +func (f *Framer) readConsistency() Consistency { + return Consistency(f.readShort()) +} + +func (f *Framer) readBytesMap() map[string][]byte { + size := f.readShort() + m := make(map[string][]byte, size) + + for i := 0; i < int(size); i++ { + k := f.readString() + v := f.readBytes() + m[k] = v + } + + return m +} + +func (f *Framer) readStringMultiMap() map[string][]string { + size := f.readShort() + m := make(map[string][]string, size) + + for i := 0; i < int(size); i++ { + k := f.readString() + v := f.readStringList() + m[k] = v + } + + return m +} + +func (f *Framer) writeByte(b byte) { + f.Buf = append(f.Buf, b) +} + +func appendBytes(p []byte, d []byte) []byte { + if d == nil { + return appendInt(p, -1) + } + p = appendInt(p, int32(len(d))) + p = append(p, d...) + return p +} + +func appendShort(p []byte, n uint16) []byte { + return append(p, + byte(n>>8), + byte(n), + ) +} + +func appendInt(p []byte, n int32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func appendUint(p []byte, n uint32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func appendLong(p []byte, n int64) []byte { + return append(p, + byte(n>>56), + byte(n>>48), + byte(n>>40), + byte(n>>32), + byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n), + ) +} + +func (f *Framer) writeCustomPayload(customPayload *map[string][]byte) { + if len(*customPayload) > 0 { + if f.proto < ProtoVersion4 { + panic("Custom payload is not supported with version V3 or less") + } + f.writeBytesMap(*customPayload) + } +} + +// these are protocol level binary types +func (f *Framer) writeInt(n int32) { + f.Buf = appendInt(f.Buf, n) +} + +func (f *Framer) writeUint(n uint32) { + f.Buf = appendUint(f.Buf, n) +} + +func (f *Framer) writeShort(n uint16) { + f.Buf = appendShort(f.Buf, n) +} + +func (f *Framer) writeLong(n int64) { + f.Buf = appendLong(f.Buf, n) +} + +func (f *Framer) writeString(s string) { + f.writeShort(uint16(len(s))) + f.Buf = append(f.Buf, s...) +} + +func (f *Framer) writeLongString(s string) { + f.writeInt(int32(len(s))) + f.Buf = append(f.Buf, s...) +} + +func (f *Framer) writeStringList(l []string) { + f.writeShort(uint16(len(l))) + for _, s := range l { + f.writeString(s) + } +} + +func (f *Framer) writeUnset() { + // Protocol version 4 specifies that bind variables do not require having a + // value when executing a statement. Bind variables without a value are + // called 'unset'. The 'unset' bind variable is serialized as the int + // value '-2' without following bytes. + f.writeInt(-2) +} + +func (f *Framer) writeBytes(p []byte) { + // TODO: handle null case correctly, + // [bytes] A [int] n, followed by n bytes if n >= 0. If n < 0, + // no byte should follow and the value represented is `null`. + if p == nil { + f.writeInt(-1) + } else { + f.writeInt(int32(len(p))) + f.Buf = append(f.Buf, p...) + } +} + +func (f *Framer) writeShortBytes(p []byte) { + f.writeShort(uint16(len(p))) + f.Buf = append(f.Buf, p...) +} + +func (f *Framer) writeConsistency(cons Consistency) { + f.writeShort(uint16(cons)) +} + +func (f *Framer) writeStringMap(m map[string]string) { + f.writeShort(uint16(len(m))) + for k, v := range m { + f.writeString(k) + f.writeString(v) + } +} + +func (f *Framer) writeBytesMap(m map[string][]byte) { + f.writeShort(uint16(len(m))) + for k, v := range m { + f.writeString(k) + f.writeBytes(v) + } +} + +type ErrProtocol struct{ error } + +func NewErrProtocol(format string, args ...interface{}) error { + return ErrProtocol{fmt.Errorf(format, args...)} +} diff --git a/internal/protocol/helpers.go b/internal/protocol/helpers.go new file mode 100644 index 000000000..6d70af0ab --- /dev/null +++ b/internal/protocol/helpers.go @@ -0,0 +1,275 @@ +package protocol + +import ( + "fmt" + "github.com/gocql/gocql/internal/logger" + "gopkg.in/inf.v0" + "math/big" + "reflect" + "strings" + "time" +) + +func goType(t TypeInfo) (reflect.Type, error) { + switch t.Type() { + case TypeVarchar, TypeAscii, TypeInet, TypeText: + return reflect.TypeOf(*new(string)), nil + case TypeBigInt, TypeCounter: + return reflect.TypeOf(*new(int64)), nil + case TypeTime: + return reflect.TypeOf(*new(time.Duration)), nil + case TypeTimestamp: + return reflect.TypeOf(*new(time.Time)), nil + case TypeBlob: + return reflect.TypeOf(*new([]byte)), nil + case TypeBoolean: + return reflect.TypeOf(*new(bool)), nil + case TypeFloat: + return reflect.TypeOf(*new(float32)), nil + case TypeDouble: + return reflect.TypeOf(*new(float64)), nil + case TypeInt: + return reflect.TypeOf(*new(int)), nil + case TypeSmallInt: + return reflect.TypeOf(*new(int16)), nil + case TypeTinyInt: + return reflect.TypeOf(*new(int8)), nil + case TypeDecimal: + return reflect.TypeOf(*new(*inf.Dec)), nil + case TypeUUID, TypeTimeUUID: + return reflect.TypeOf(*new(UUID)), nil + case TypeList, TypeSet: + elemType, err := goType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.SliceOf(elemType), nil + case TypeMap: + keyType, err := goType(t.(CollectionType).Key) + if err != nil { + return nil, err + } + valueType, err := goType(t.(CollectionType).Elem) + if err != nil { + return nil, err + } + return reflect.MapOf(keyType, valueType), nil + case TypeVarint: + return reflect.TypeOf(*new(*big.Int)), nil + case TypeTuple: + // what can we do here? all there is to do is to make a list of interface{} + tuple := t.(TupleTypeInfo) + return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil + case TypeUDT: + return reflect.TypeOf(make(map[string]interface{})), nil + case TypeDate: + return reflect.TypeOf(*new(time.Time)), nil + case TypeDuration: + return reflect.TypeOf(*new(Duration)), nil + default: + return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) + } +} + +func CopyBytes(p []byte) []byte { + b := make([]byte, len(p)) + copy(b, p) + return b +} + +func getCassandraBaseType(name string) Type { + switch name { + case "ascii": + return TypeAscii + case "bigint": + return TypeBigInt + case "blob": + return TypeBlob + case "boolean": + return TypeBoolean + case "counter": + return TypeCounter + case "date": + return TypeDate + case "decimal": + return TypeDecimal + case "double": + return TypeDouble + case "duration": + return TypeDuration + case "float": + return TypeFloat + case "int": + return TypeInt + case "smallint": + return TypeSmallInt + case "tinyint": + return TypeTinyInt + case "time": + return TypeTime + case "timestamp": + return TypeTimestamp + case "uuid": + return TypeUUID + case "varchar": + return TypeVarchar + case "text": + return TypeText + case "varint": + return TypeVarint + case "timeuuid": + return TypeTimeUUID + case "inet": + return TypeInet + case "MapType": + return TypeMap + case "ListType": + return TypeList + case "SetType": + return TypeSet + case "TupleType": + return TypeTuple + default: + return TypeCustom + } +} + +func GetCassandraType(name string, logger logger.StdLogger) TypeInfo { + if strings.HasPrefix(name, "frozen<") { + return GetCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) + } else if strings.HasPrefix(name, "set<") { + return CollectionType{ + NativeType: NativeType{Typ: TypeSet}, + Elem: GetCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), + } + } else if strings.HasPrefix(name, "list<") { + return CollectionType{ + NativeType: NativeType{Typ: TypeList}, + Elem: GetCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), + } + } else if strings.HasPrefix(name, "map<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + if len(names) != 2 { + logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) + return NativeType{ + Typ: TypeCustom, + } + } + return CollectionType{ + NativeType: NativeType{Typ: TypeMap}, + Key: GetCassandraType(names[0], logger), + Elem: GetCassandraType(names[1], logger), + } + } else if strings.HasPrefix(name, "tuple<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + types := make([]TypeInfo, len(names)) + + for i, name := range names { + types[i] = GetCassandraType(name, logger) + } + + return TupleTypeInfo{ + NativeType: NativeType{Typ: TypeTuple}, + Elems: types, + } + } else { + return NativeType{ + Typ: getCassandraBaseType(name), + } + } +} + +func splitCompositeTypes(name string) []string { + if !strings.Contains(name, "<") { + return strings.Split(name, ", ") + } + var parts []string + lessCount := 0 + segment := "" + for _, char := range name { + if char == ',' && lessCount == 0 { + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + segment = "" + continue + } + segment += string(char) + if char == '<' { + lessCount++ + } else if char == '>' { + lessCount-- + } + } + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + return parts +} + +func ApacheToCassandraType(t string) string { + t = strings.Replace(t, ApacheCassandraTypePrefix, "", -1) + t = strings.Replace(t, "(", "<", -1) + t = strings.Replace(t, ")", ">", -1) + types := strings.FieldsFunc(t, func(r rune) bool { + return r == '<' || r == '>' || r == ',' + }) + for _, typ := range types { + t = strings.Replace(t, typ, GetApacheCassandraType(typ).String(), -1) + } + // This is done so it exactly matches what Cassandra returns + return strings.Replace(t, ",", ", ", -1) +} + +func GetApacheCassandraType(class string) Type { + switch strings.TrimPrefix(class, ApacheCassandraTypePrefix) { + case "AsciiType": + return TypeAscii + case "LongType": + return TypeBigInt + case "BytesType": + return TypeBlob + case "BooleanType": + return TypeBoolean + case "CounterColumnType": + return TypeCounter + case "DecimalType": + return TypeDecimal + case "DoubleType": + return TypeDouble + case "FloatType": + return TypeFloat + case "Int32Type": + return TypeInt + case "ShortType": + return TypeSmallInt + case "ByteType": + return TypeTinyInt + case "TimeType": + return TypeTime + case "DateType", "TimestampType": + return TypeTimestamp + case "UUIDType", "LexicalUUIDType": + return TypeUUID + case "UTF8Type": + return TypeVarchar + case "IntegerType": + return TypeVarint + case "TimeUUIDType": + return TypeTimeUUID + case "InetAddressType": + return TypeInet + case "MapType": + return TypeMap + case "ListType": + return TypeList + case "SetType": + return TypeSet + case "TupleType": + return TypeTuple + case "DurationType": + return TypeDuration + default: + return TypeCustom + } +} diff --git a/internal/protocol/session.go b/internal/protocol/session.go new file mode 100644 index 000000000..0dc2333e1 --- /dev/null +++ b/internal/protocol/session.go @@ -0,0 +1,16 @@ +package protocol + +import "fmt" + +type BatchType byte + +type ColumnInfo struct { + Keyspace string + Table string + Name string + TypeInfo TypeInfo +} + +func (c ColumnInfo) String() string { + return fmt.Sprintf("[column keyspace=%s table=%s name=%s type=%v]", c.Keyspace, c.Table, c.Name, c.TypeInfo) +} diff --git a/logger.go b/logger.go index 246a117d7..026b2c10d 100644 --- a/logger.go +++ b/logger.go @@ -25,40 +25,15 @@ package gocql import ( - "bytes" - "fmt" - "log" + "github.com/gocql/gocql/internal/logger" ) -type StdLogger interface { - Print(v ...interface{}) - Printf(format string, v ...interface{}) - Println(v ...interface{}) -} +type StdLogger = logger.StdLogger -type nopLogger struct{} +type nopLogger = logger.NopLogger -func (n nopLogger) Print(_ ...interface{}) {} +type testLogger = logger.TestLogger -func (n nopLogger) Printf(_ string, _ ...interface{}) {} +type defaultLogger = logger.DefaultLogger -func (n nopLogger) Println(_ ...interface{}) {} - -type testLogger struct { - capture bytes.Buffer -} - -func (l *testLogger) Print(v ...interface{}) { fmt.Fprint(&l.capture, v...) } -func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) } -func (l *testLogger) Println(v ...interface{}) { fmt.Fprintln(&l.capture, v...) } -func (l *testLogger) String() string { return l.capture.String() } - -type defaultLogger struct{} - -func (l *defaultLogger) Print(v ...interface{}) { log.Print(v...) } -func (l *defaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) } -func (l *defaultLogger) Println(v ...interface{}) { log.Println(v...) } - -// Logger for logging messages. -// Deprecated: Use ClusterConfig.Logger instead. -var Logger StdLogger = &defaultLogger{} +var Logger = logger.Logger diff --git a/marshal.go b/marshal.go index 4d0adb923..ff0b46be6 100644 --- a/marshal.go +++ b/marshal.go @@ -25,42 +25,16 @@ package gocql import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "math" - "math/big" - "math/bits" - "net" - "reflect" - "strconv" - "strings" - "time" - - "gopkg.in/inf.v0" -) - -var ( - bigOne = big.NewInt(1) - emptyValue reflect.Value + "github.com/gocql/gocql/internal/protocol" ) var ( - ErrorUDTUnavailable = errors.New("UDT are not available on protocols less than 3, please update config") + ErrorUDTUnavailable = protocol.ErrorUDTUnavailable ) -// Marshaler is the interface implemented by objects that can marshal -// themselves into values understood by Cassandra. -type Marshaler interface { - MarshalCQL(info TypeInfo) ([]byte, error) -} +type Marshaler = protocol.Marshaler -// Unmarshaler is the interface implemented by objects that can unmarshal -// a Cassandra specific description of themselves. -type Unmarshaler interface { - UnmarshalCQL(info TypeInfo, data []byte) error -} +type Unmarshaler = protocol.Unmarshaler // Marshal returns the CQL encoding of the value for the Cassandra // internal type described by the info parameter. @@ -110,76 +84,7 @@ type Unmarshaler interface { // duration | time.Duration | // duration | gocql.Duration | // duration | string | parsed with time.ParseDuration -func Marshal(info TypeInfo, value interface{}) ([]byte, error) { - if info.Version() < protoVersion1 { - panic("protocol version not set") - } - - if valueRef := reflect.ValueOf(value); valueRef.Kind() == reflect.Ptr { - if valueRef.IsNil() { - return nil, nil - } else if v, ok := value.(Marshaler); ok { - return v.MarshalCQL(info) - } else { - return Marshal(info, valueRef.Elem().Interface()) - } - } - - if v, ok := value.(Marshaler); ok { - return v.MarshalCQL(info) - } - - switch info.Type() { - case TypeVarchar, TypeAscii, TypeBlob, TypeText: - return marshalVarchar(info, value) - case TypeBoolean: - return marshalBool(info, value) - case TypeTinyInt: - return marshalTinyInt(info, value) - case TypeSmallInt: - return marshalSmallInt(info, value) - case TypeInt: - return marshalInt(info, value) - case TypeBigInt, TypeCounter: - return marshalBigInt(info, value) - case TypeFloat: - return marshalFloat(info, value) - case TypeDouble: - return marshalDouble(info, value) - case TypeDecimal: - return marshalDecimal(info, value) - case TypeTime: - return marshalTime(info, value) - case TypeTimestamp: - return marshalTimestamp(info, value) - case TypeList, TypeSet: - return marshalList(info, value) - case TypeMap: - return marshalMap(info, value) - case TypeUUID, TypeTimeUUID: - return marshalUUID(info, value) - case TypeVarint: - return marshalVarint(info, value) - case TypeInet: - return marshalInet(info, value) - case TypeTuple: - return marshalTuple(info, value) - case TypeUDT: - return marshalUDT(info, value) - case TypeDate: - return marshalDate(info, value) - case TypeDuration: - return marshalDuration(info, value) - } - - // detect protocol 2 UDT - if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { - return nil, ErrorUDTUnavailable - } - - // TODO(tux21b): add the remaining types - return nil, fmt.Errorf("can not marshal %T into %s", value, info) -} +var Marshal = protocol.Marshal // Unmarshal parses the CQL encoded data based on the info parameter that // describes the Cassandra internal data type and stores the result in the @@ -222,2526 +127,56 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { // date | *time.Time | time of beginning of the day (in UTC) // date | *string | formatted with 2006-01-02 format // duration | *gocql.Duration | -func Unmarshal(info TypeInfo, data []byte, value interface{}) error { - if v, ok := value.(Unmarshaler); ok { - return v.UnmarshalCQL(info, data) - } - - if isNullableValue(value) { - return unmarshalNullable(info, data, value) - } - - switch info.Type() { - case TypeVarchar, TypeAscii, TypeBlob, TypeText: - return unmarshalVarchar(info, data, value) - case TypeBoolean: - return unmarshalBool(info, data, value) - case TypeInt: - return unmarshalInt(info, data, value) - case TypeBigInt, TypeCounter: - return unmarshalBigInt(info, data, value) - case TypeVarint: - return unmarshalVarint(info, data, value) - case TypeSmallInt: - return unmarshalSmallInt(info, data, value) - case TypeTinyInt: - return unmarshalTinyInt(info, data, value) - case TypeFloat: - return unmarshalFloat(info, data, value) - case TypeDouble: - return unmarshalDouble(info, data, value) - case TypeDecimal: - return unmarshalDecimal(info, data, value) - case TypeTime: - return unmarshalTime(info, data, value) - case TypeTimestamp: - return unmarshalTimestamp(info, data, value) - case TypeList, TypeSet: - return unmarshalList(info, data, value) - case TypeMap: - return unmarshalMap(info, data, value) - case TypeTimeUUID: - return unmarshalTimeUUID(info, data, value) - case TypeUUID: - return unmarshalUUID(info, data, value) - case TypeInet: - return unmarshalInet(info, data, value) - case TypeTuple: - return unmarshalTuple(info, data, value) - case TypeUDT: - return unmarshalUDT(info, data, value) - case TypeDate: - return unmarshalDate(info, data, value) - case TypeDuration: - return unmarshalDuration(info, data, value) - } - - // detect protocol 2 UDT - if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { - return ErrorUDTUnavailable - } - - // TODO(tux21b): add the remaining types - return fmt.Errorf("can not unmarshal %s into %T", info, value) -} - -func isNullableValue(value interface{}) bool { - v := reflect.ValueOf(value) - return v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Ptr -} - -func isNullData(info TypeInfo, data []byte) bool { - return data == nil -} - -func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { - valueRef := reflect.ValueOf(value) - - if isNullData(info, data) { - nilValue := reflect.Zero(valueRef.Type().Elem()) - valueRef.Elem().Set(nilValue) - return nil - } - - newValue := reflect.New(valueRef.Type().Elem().Elem()) - valueRef.Elem().Set(newValue) - return Unmarshal(info, data, newValue.Interface()) -} - -func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case string: - return []byte(v), nil - case []byte: - return v, nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - t := rv.Type() - k := t.Kind() - switch { - case k == reflect.String: - return []byte(rv.String()), nil - case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: - return rv.Bytes(), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *string: - *v = string(data) - return nil - case *[]byte: - if data != nil { - *v = append((*v)[:0], data...) - } else { - *v = nil - } - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - t := rv.Type() - k := t.Kind() - switch { - case k == reflect.String: - rv.SetString(string(data)) - return nil - case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: - var dataCopy []byte - if data != nil { - dataCopy = make([]byte, len(data)) - copy(dataCopy, data) - } - rv.SetBytes(dataCopy) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int16: - return encShort(v), nil - case uint16: - return encShort(int16(v)), nil - case int8: - return encShort(int16(v)), nil - case uint8: - return encShort(int16(v)), nil - case int: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case int32: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case int64: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint32: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint64: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case string: - n, err := strconv.ParseInt(v, 10, 16) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) - } - return encShort(int16(n)), nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int8: - return []byte{byte(v)}, nil - case uint8: - return []byte{byte(v)}, nil - case int16: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint16: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int32: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int64: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint32: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint64: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case string: - n, err := strconv.ParseInt(v, 10, 8) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) - } - return []byte{byte(n)}, nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int: - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case uint: - if v > math.MaxUint32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case int64: - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case uint64: - if v > math.MaxUint32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case int32: - return encInt(v), nil - case uint32: - return encInt(int32(v)), nil - case int16: - return encInt(int32(v)), nil - case uint16: - return encInt(int32(v)), nil - case int8: - return encInt(int32(v)), nil - case uint8: - return encInt(int32(v)), nil - case string: - i, err := strconv.ParseInt(v, 10, 32) - if err != nil { - return nil, marshalErrorf("can not marshal string to int: %s", err) - } - return encInt(int32(i)), nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encInt(x int32) []byte { - return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - -func decInt(x []byte) int32 { - if len(x) != 4 { - return 0 - } - return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) -} - -func encShort(x int16) []byte { - p := make([]byte, 2) - p[0] = byte(x >> 8) - p[1] = byte(x) - return p -} - -func decShort(p []byte) int16 { - if len(p) != 2 { - return 0 - } - return int16(p[0])<<8 | int16(p[1]) -} - -func decTiny(p []byte) int8 { - if len(p) != 1 { - return 0 - } - return int8(p[0]) -} - -func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int: - return encBigInt(int64(v)), nil - case uint: - if uint64(v) > math.MaxInt64 { - return nil, marshalErrorf("marshal bigint: value %d out of range", v) - } - return encBigInt(int64(v)), nil - case int64: - return encBigInt(v), nil - case uint64: - return encBigInt(int64(v)), nil - case int32: - return encBigInt(int64(v)), nil - case uint32: - return encBigInt(int64(v)), nil - case int16: - return encBigInt(int64(v)), nil - case uint16: - return encBigInt(int64(v)), nil - case int8: - return encBigInt(int64(v)), nil - case uint8: - return encBigInt(int64(v)), nil - case big.Int: - return encBigInt2C(&v), nil - case string: - i, err := strconv.ParseInt(value.(string), 10, 64) - if err != nil { - return nil, marshalErrorf("can not marshal string to bigint: %s", err) - } - return encBigInt(i), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - return encBigInt(v), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxInt64 { - return nil, marshalErrorf("marshal bigint: value %d out of range", v) - } - return encBigInt(int64(v)), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encBigInt(x int64) []byte { - return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), - byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - -func bytesToInt64(data []byte) (ret int64) { - for i := range data { - ret |= int64(data[i]) << (8 * uint(len(data)-i-1)) - } - return ret -} - -func bytesToUint64(data []byte) (ret uint64) { - for i := range data { - ret |= uint64(data[i]) << (8 * uint(len(data)-i-1)) - } - return ret -} - -func unmarshalBigInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, decBigInt(data), data, value) -} - -func unmarshalInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decInt(data)), data, value) -} - -func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decShort(data)), data, value) -} - -func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decTiny(data)), data, value) -} - -func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case *big.Int: - return unmarshalIntlike(info, 0, data, value) - case *uint64: - if len(data) == 9 && data[0] == 0 { - *v = bytesToUint64(data[1:]) - return nil - } - } - - if len(data) > 8 { - return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value) - } - - int64Val := bytesToInt64(data) - if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 { - int64Val -= (1 << uint(len(data)*8)) - } - return unmarshalIntlike(info, int64Val, data, value) -} - -func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { - var ( - retBytes []byte - err error - ) - - switch v := value.(type) { - case unsetColumn: - return nil, nil - case uint64: - if v > uint64(math.MaxInt64) { - retBytes = make([]byte, 9) - binary.BigEndian.PutUint64(retBytes[1:], v) - } else { - retBytes = make([]byte, 8) - binary.BigEndian.PutUint64(retBytes, v) - } - default: - retBytes, err = marshalBigInt(info, value) - } - - if err == nil { - // trim down to most significant byte - i := 0 - for ; i < len(retBytes)-1; i++ { - b0 := retBytes[i] - if b0 != 0 && b0 != 0xFF { - break - } - - b1 := retBytes[i+1] - if b0 == 0 && b1 != 0 { - if b1&0x80 == 0 { - i++ - } - break - } - - if b0 == 0xFF && b1 != 0xFF { - if b1&0x80 > 0 { - i++ - } - break - } - } - retBytes = retBytes[i:] - } - - return retBytes, err -} - -func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error { - switch v := value.(type) { - case *int: - if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } - *v = int(int64Val) - return nil - case *uint: - unitVal := uint64(int64Val) - switch info.Type() { - case TypeInt: - *v = uint(unitVal) & 0xFFFFFFFF - case TypeSmallInt: - *v = uint(unitVal) & 0xFFFF - case TypeTinyInt: - *v = uint(unitVal) & 0xFF - default: - if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v) - } - *v = uint(unitVal) - } - return nil - case *int64: - *v = int64Val - return nil - case *uint64: - switch info.Type() { - case TypeInt: - *v = uint64(int64Val) & 0xFFFFFFFF - case TypeSmallInt: - *v = uint64(int64Val) & 0xFFFF - case TypeTinyInt: - *v = uint64(int64Val) & 0xFF - default: - *v = uint64(int64Val) - } - return nil - case *int32: - if int64Val < math.MinInt32 || int64Val > math.MaxInt32 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } - *v = int32(int64Val) - return nil - case *uint32: - switch info.Type() { - case TypeInt: - *v = uint32(int64Val) & 0xFFFFFFFF - case TypeSmallInt: - *v = uint32(int64Val) & 0xFFFF - case TypeTinyInt: - *v = uint32(int64Val) & 0xFF - default: - if int64Val < 0 || int64Val > math.MaxUint32 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } - *v = uint32(int64Val) & 0xFFFFFFFF - } - return nil - case *int16: - if int64Val < math.MinInt16 || int64Val > math.MaxInt16 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } - *v = int16(int64Val) - return nil - case *uint16: - switch info.Type() { - case TypeSmallInt: - *v = uint16(int64Val) & 0xFFFF - case TypeTinyInt: - *v = uint16(int64Val) & 0xFF - default: - if int64Val < 0 || int64Val > math.MaxUint16 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } - *v = uint16(int64Val) & 0xFFFF - } - return nil - case *int8: - if int64Val < math.MinInt8 || int64Val > math.MaxInt8 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } - *v = int8(int64Val) - return nil - case *uint8: - if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } - *v = uint8(int64Val) & 0xFF - return nil - case *big.Int: - decBigInt2C(data, v) - return nil - case *string: - *v = strconv.FormatInt(int64Val, 10) - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - - switch rv.Type().Kind() { - case reflect.Int: - if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) - } - rv.SetInt(int64Val) - return nil - case reflect.Int64: - rv.SetInt(int64Val) - return nil - case reflect.Int32: - if int64Val < math.MinInt32 || int64Val > math.MaxInt32 { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) - } - rv.SetInt(int64Val) - return nil - case reflect.Int16: - if int64Val < math.MinInt16 || int64Val > math.MaxInt16 { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) - } - rv.SetInt(int64Val) - return nil - case reflect.Int8: - if int64Val < math.MinInt8 || int64Val > math.MaxInt8 { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) - } - rv.SetInt(int64Val) - return nil - case reflect.Uint: - unitVal := uint64(int64Val) - switch info.Type() { - case TypeInt: - rv.SetUint(unitVal & 0xFFFFFFFF) - case TypeSmallInt: - rv.SetUint(unitVal & 0xFFFF) - case TypeTinyInt: - rv.SetUint(unitVal & 0xFF) - default: - if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { - return unmarshalErrorf("unmarshal int: value %d out of range for %s", unitVal, rv.Type()) - } - rv.SetUint(unitVal) - } - return nil - case reflect.Uint64: - unitVal := uint64(int64Val) - switch info.Type() { - case TypeInt: - rv.SetUint(unitVal & 0xFFFFFFFF) - case TypeSmallInt: - rv.SetUint(unitVal & 0xFFFF) - case TypeTinyInt: - rv.SetUint(unitVal & 0xFF) - default: - rv.SetUint(unitVal) - } - return nil - case reflect.Uint32: - unitVal := uint64(int64Val) - switch info.Type() { - case TypeInt: - rv.SetUint(unitVal & 0xFFFFFFFF) - case TypeSmallInt: - rv.SetUint(unitVal & 0xFFFF) - case TypeTinyInt: - rv.SetUint(unitVal & 0xFF) - default: - if int64Val < 0 || int64Val > math.MaxUint32 { - return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) - } - rv.SetUint(unitVal & 0xFFFFFFFF) - } - return nil - case reflect.Uint16: - unitVal := uint64(int64Val) - switch info.Type() { - case TypeSmallInt: - rv.SetUint(unitVal & 0xFFFF) - case TypeTinyInt: - rv.SetUint(unitVal & 0xFF) - default: - if int64Val < 0 || int64Val > math.MaxUint16 { - return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) - } - rv.SetUint(unitVal & 0xFFFF) - } - return nil - case reflect.Uint8: - if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { - return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) - } - rv.SetUint(uint64(int64Val) & 0xff) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func decBigInt(data []byte) int64 { - if len(data) != 8 { - return 0 - } - return int64(data[0])<<56 | int64(data[1])<<48 | - int64(data[2])<<40 | int64(data[3])<<32 | - int64(data[4])<<24 | int64(data[5])<<16 | - int64(data[6])<<8 | int64(data[7]) -} - -func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case bool: - return encBool(v), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Bool: - return encBool(rv.Bool()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encBool(v bool) []byte { - if v { - return []byte{1} - } - return []byte{0} -} - -func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *bool: - *v = decBool(data) - return nil - } - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Bool: - rv.SetBool(decBool(data)) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func decBool(v []byte) bool { - if len(v) == 0 { - return false - } - return v[0] != 0 -} - -func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case float32: - return encInt(int32(math.Float32bits(v))), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Float32: - return encInt(int32(math.Float32bits(float32(rv.Float())))), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *float32: - *v = math.Float32frombits(uint32(decInt(data))) - return nil - } - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Float32: - rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case float64: - return encBigInt(int64(math.Float64bits(v))), nil - } - if value == nil { - return nil, nil - } - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Float64: - return encBigInt(int64(math.Float64bits(rv.Float()))), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *float64: - *v = math.Float64frombits(uint64(decBigInt(data))) - return nil - } - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Float64: - rv.SetFloat(math.Float64frombits(uint64(decBigInt(data)))) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { - if value == nil { - return nil, nil - } - - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case inf.Dec: - unscaled := encBigInt2C(v.UnscaledBig()) - if unscaled == nil { - return nil, marshalErrorf("can not marshal %T into %s", value, info) - } - - buf := make([]byte, 4+len(unscaled)) - copy(buf[0:4], encInt(int32(v.Scale()))) - copy(buf[4:], unscaled) - return buf, nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *inf.Dec: - if len(data) < 4 { - return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) - } - scale := decInt(data[0:4]) - unscaled := decBigInt2C(data[4:], nil) - *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -// decBigInt2C sets the value of n to the big-endian two's complement -// value stored in the given data. If data[0]&80 != 0, the number -// is negative. If data is empty, the result will be 0. -func decBigInt2C(data []byte, n *big.Int) *big.Int { - if n == nil { - n = new(big.Int) - } - n.SetBytes(data) - if len(data) > 0 && data[0]&0x80 > 0 { - n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) - } - return n -} - -// encBigInt2C returns the big-endian two's complement -// form of n. -func encBigInt2C(n *big.Int) []byte { - switch n.Sign() { - case 0: - return []byte{0} - case 1: - b := n.Bytes() - if b[0]&0x80 > 0 { - b = append([]byte{0}, b...) - } - return b - case -1: - length := uint(n.BitLen()/8+1) * 8 - b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() - // When the most significant bit is on a byte - // boundary, we can get some extra significant - // bits, so strip them off when that happens. - if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { - b = b[1:] - } - return b - } - return nil -} - -func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Duration: - return encBigInt(v.Nanoseconds()), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - return encBigInt(x), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *int64: - *v = decBigInt(data) - return nil - case *time.Duration: - *v = time.Duration(decBigInt(data)) - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Int64: - rv.SetInt(decBigInt(data)) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *int64: - *v = decBigInt(data) - return nil - case *time.Time: - if len(data) == 0 { - *v = time.Time{} - return nil - } - x := decBigInt(data) - sec := x / 1000 - nsec := (x - sec*1000) * 1000000 - *v = time.Unix(sec, nsec).In(time.UTC) - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - switch rv.Type().Kind() { - case reflect.Int64: - rv.SetInt(decBigInt(data)) - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -const millisecondsInADay int64 = 24 * 60 * 60 * 1000 - -func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { - var timestamp int64 - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - timestamp = v - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case *time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case string: - if v == "" { - return []byte{}, nil - } - t, err := time.Parse("2006-01-02", v) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) - } - timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - } - - if value == nil { - return nil, nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *time.Time: - if len(data) == 0 { - *v = time.Time{} - return nil - } - var origin uint32 = 1 << 31 - var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay - *v = time.UnixMilli(timestamp).In(time.UTC) - return nil - case *string: - if len(data) == 0 { - *v = "" - return nil - } - var origin uint32 = 1 << 31 - var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay - *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encVints(0, 0, v), nil - case time.Duration: - return encVints(0, 0, v.Nanoseconds()), nil - case string: - d, err := time.ParseDuration(v) - if err != nil { - return nil, err - } - return encVints(0, 0, d.Nanoseconds()), nil - case Duration: - return encVints(v.Months, v.Days, v.Nanoseconds), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *Duration: - if len(data) == 0 { - *v = Duration{ - Months: 0, - Days: 0, - Nanoseconds: 0, - } - return nil - } - months, days, nanos, err := decVints(data) - if err != nil { - return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) - } - *v = Duration{ - Months: months, - Days: days, - Nanoseconds: nanos, - } - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func decVints(data []byte) (int32, int32, int64, error) { - month, i, err := decVint(data, 0) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) - } - days, i, err := decVint(data, i) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) - } - nanos, _, err := decVint(data, i) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) - } - return int32(month), int32(days), nanos, err -} - -func decVint(data []byte, start int) (int64, int, error) { - if len(data) <= start { - return 0, 0, errors.New("unexpected eof") - } - firstByte := data[start] - if firstByte&0x80 == 0 { - return decIntZigZag(uint64(firstByte)), start + 1, nil - } - numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 - ret := uint64(firstByte & (0xff >> uint(numBytes))) - if len(data) < start+numBytes+1 { - return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) - } - for i := start; i < start+numBytes; i++ { - ret <<= 8 - ret |= uint64(data[i+1] & 0xff) - } - return decIntZigZag(ret), start + numBytes + 1, nil -} - -func decIntZigZag(n uint64) int64 { - return int64((n >> 1) ^ -(n & 1)) -} - -func encIntZigZag(n int64) uint64 { - return uint64((n >> 63) ^ (n << 1)) -} - -func encVints(months int32, seconds int32, nanos int64) []byte { - buf := append(encVint(int64(months)), encVint(int64(seconds))...) - return append(buf, encVint(nanos)...) -} - -func encVint(v int64) []byte { - vEnc := encIntZigZag(v) - lead0 := bits.LeadingZeros64(vEnc) - numBytes := (639 - lead0*9) >> 6 - - // It can be 1 or 0 is v ==0 - if numBytes <= 1 { - return []byte{byte(vEnc)} - } - extraBytes := numBytes - 1 - var buf = make([]byte, numBytes) - for i := extraBytes; i >= 0; i-- { - buf[i] = byte(vEnc) - vEnc >>= 8 - } - buf[0] |= byte(^(0xff >> uint(extraBytes))) - return buf -} - -func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { - if info.proto > protoVersion2 { - if n > math.MaxInt32 { - return marshalErrorf("marshal: collection too large") - } - - buf.WriteByte(byte(n >> 24)) - buf.WriteByte(byte(n >> 16)) - buf.WriteByte(byte(n >> 8)) - buf.WriteByte(byte(n)) - } else { - if n > math.MaxUint16 { - return marshalErrorf("marshal: collection too large") - } - - buf.WriteByte(byte(n >> 8)) - buf.WriteByte(byte(n)) - } - - return nil -} - -func marshalList(info TypeInfo, value interface{}) ([]byte, error) { - listInfo, ok := info.(CollectionType) - if !ok { - return nil, marshalErrorf("marshal: can not marshal non collection type into list") - } - - if value == nil { - return nil, nil - } else if _, ok := value.(unsetColumn); ok { - return nil, nil - } - - rv := reflect.ValueOf(value) - t := rv.Type() - k := t.Kind() - if k == reflect.Slice && rv.IsNil() { - return nil, nil - } - - switch k { - case reflect.Slice, reflect.Array: - buf := &bytes.Buffer{} - n := rv.Len() - - if err := writeCollectionSize(listInfo, n, buf); err != nil { - return nil, err - } - - for i := 0; i < n; i++ { - item, err := Marshal(listInfo.Elem, rv.Index(i).Interface()) - if err != nil { - return nil, err - } - itemLen := len(item) - // Set the value to null for supported protocols - if item == nil && listInfo.proto > protoVersion2 { - itemLen = -1 - } - if err := writeCollectionSize(listInfo, itemLen, buf); err != nil { - return nil, err - } - buf.Write(item) - } - return buf.Bytes(), nil - case reflect.Map: - elem := t.Elem() - if elem.Kind() == reflect.Struct && elem.NumField() == 0 { - rkeys := rv.MapKeys() - keys := make([]interface{}, len(rkeys)) - for i := 0; i < len(keys); i++ { - keys[i] = rkeys[i].Interface() - } - return marshalList(listInfo, keys) - } - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) { - if info.proto > protoVersion2 { - if len(data) < 4 { - return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") - } - size = int(int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3])) - read = 4 - } else { - if len(data) < 2 { - return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") - } - size = int(data[0])<<8 | int(data[1]) - read = 2 - } - return -} - -func unmarshalList(info TypeInfo, data []byte, value interface{}) error { - listInfo, ok := info.(CollectionType) - if !ok { - return unmarshalErrorf("unmarshal: can not unmarshal none collection type into list") - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - t := rv.Type() - k := t.Kind() - - switch k { - case reflect.Slice, reflect.Array: - if data == nil { - if k == reflect.Array { - return unmarshalErrorf("unmarshal list: can not store nil in array value") - } - if rv.IsNil() { - return nil - } - rv.Set(reflect.Zero(t)) - return nil - } - n, p, err := readCollectionSize(listInfo, data) - if err != nil { - return err - } - data = data[p:] - if k == reflect.Array { - if rv.Len() != n { - return unmarshalErrorf("unmarshal list: array with wrong size") - } - } else { - rv.Set(reflect.MakeSlice(t, n, n)) - } - for i := 0; i < n; i++ { - m, p, err := readCollectionSize(listInfo, data) - if err != nil { - return err - } - data = data[p:] - // In case m < 0, the value is null, and unmarshalData should be nil. - var unmarshalData []byte - if m >= 0 { - if len(data) < m { - return unmarshalErrorf("unmarshal list: unexpected eof") - } - unmarshalData = data[:m] - data = data[m:] - } - if err := Unmarshal(listInfo.Elem, unmarshalData, rv.Index(i).Addr().Interface()); err != nil { - return err - } - } - return nil - } - return unmarshalErrorf("can not unmarshal %s into %T", info, value) -} - -func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { - mapInfo, ok := info.(CollectionType) - if !ok { - return nil, marshalErrorf("marshal: can not marshal none collection type into map") - } - - if value == nil { - return nil, nil - } else if _, ok := value.(unsetColumn); ok { - return nil, nil - } - - rv := reflect.ValueOf(value) - - t := rv.Type() - if t.Kind() != reflect.Map { - return nil, marshalErrorf("can not marshal %T into %s", value, info) - } - - if rv.IsNil() { - return nil, nil - } - - buf := &bytes.Buffer{} - n := rv.Len() +var Unmarshal = protocol.Unmarshal - if err := writeCollectionSize(mapInfo, n, buf); err != nil { - return nil, err - } +type UDTMarshaler = protocol.UDTMarshaler - keys := rv.MapKeys() - for _, key := range keys { - item, err := Marshal(mapInfo.Key, key.Interface()) - if err != nil { - return nil, err - } - itemLen := len(item) - // Set the key to null for supported protocols - if item == nil && mapInfo.proto > protoVersion2 { - itemLen = -1 - } - if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { - return nil, err - } - buf.Write(item) +type UDTUnmarshaler = protocol.UDTUnmarshaler - item, err = Marshal(mapInfo.Elem, rv.MapIndex(key).Interface()) - if err != nil { - return nil, err - } - itemLen = len(item) - // Set the value to null for supported protocols - if item == nil && mapInfo.proto > protoVersion2 { - itemLen = -1 - } - if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { - return nil, err - } - buf.Write(item) - } - return buf.Bytes(), nil -} +type TypeInfo = protocol.TypeInfo -func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { - mapInfo, ok := info.(CollectionType) - if !ok { - return unmarshalErrorf("unmarshal: can not unmarshal none collection type into map") - } +type NativeType = protocol.NativeType - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - rv = rv.Elem() - t := rv.Type() - if t.Kind() != reflect.Map { - return unmarshalErrorf("can not unmarshal %s into %T", info, value) - } - if data == nil { - rv.Set(reflect.Zero(t)) - return nil - } - n, p, err := readCollectionSize(mapInfo, data) - if err != nil { - return err - } - if n < 0 { - return unmarshalErrorf("negative map size %d", n) - } - rv.Set(reflect.MakeMapWithSize(t, n)) - data = data[p:] - for i := 0; i < n; i++ { - m, p, err := readCollectionSize(mapInfo, data) - if err != nil { - return err - } - data = data[p:] - key := reflect.New(t.Key()) - // In case m < 0, the key is null, and unmarshalData should be nil. - var unmarshalData []byte - if m >= 0 { - if len(data) < m { - return unmarshalErrorf("unmarshal map: unexpected eof") - } - unmarshalData = data[:m] - data = data[m:] - } - if err := Unmarshal(mapInfo.Key, unmarshalData, key.Interface()); err != nil { - return err - } +type TupleTypeInfo = protocol.TupleTypeInfo - m, p, err = readCollectionSize(mapInfo, data) - if err != nil { - return err - } - data = data[p:] - val := reflect.New(t.Elem()) +type UDTField = protocol.UDTField - // In case m < 0, the value is null, and unmarshalData should be nil. - unmarshalData = nil - if m >= 0 { - if len(data) < m { - return unmarshalErrorf("unmarshal map: unexpected eof") - } - unmarshalData = data[:m] - data = data[m:] - } - if err := Unmarshal(mapInfo.Elem, unmarshalData, val.Interface()); err != nil { - return err - } - - rv.SetMapIndex(key.Elem(), val.Elem()) - } - return nil -} - -func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { - switch val := value.(type) { - case unsetColumn: - return nil, nil - case UUID: - return val.Bytes(), nil - case [16]byte: - return val[:], nil - case []byte: - if len(val) != 16 { - return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info) - } - return val, nil - case string: - b, err := ParseUUID(val) - if err != nil { - return nil, err - } - return b[:], nil - } - - if value == nil { - return nil, nil - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { - if len(data) == 0 { - switch v := value.(type) { - case *string: - *v = "" - case *[]byte: - *v = nil - case *UUID: - *v = UUID{} - default: - return unmarshalErrorf("can not unmarshal X %s into %T", info, value) - } - - return nil - } - - if len(data) != 16 { - return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long") - } - - switch v := value.(type) { - case *[16]byte: - copy((*v)[:], data) - return nil - case *UUID: - copy((*v)[:], data) - return nil - } - - u, err := UUIDFromBytes(data) - if err != nil { - return unmarshalErrorf("unable to parse UUID: %s", err) - } - - switch v := value.(type) { - case *string: - *v = u.String() - return nil - case *[]byte: - *v = u[:] - return nil - } - return unmarshalErrorf("can not unmarshal X %s into %T", info, value) -} - -func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *time.Time: - id, err := UUIDFromBytes(data) - if err != nil { - return err - } else if id.Version() != 1 { - return unmarshalErrorf("invalid timeuuid") - } - *v = id.Time() - return nil - default: - return unmarshalUUID(info, data, value) - } -} - -func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { - // we return either the 4 or 16 byte representation of an - // ip address here otherwise the db value will be prefixed - // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 - switch val := value.(type) { - case unsetColumn: - return nil, nil - case net.IP: - t := val.To4() - if t == nil { - return val.To16(), nil - } - return t, nil - case string: - b := net.ParseIP(val) - if b != nil { - t := b.To4() - if t == nil { - return b.To16(), nil - } - return t, nil - } - return nil, marshalErrorf("cannot marshal. invalid ip string %s", val) - } - - if value == nil { - return nil, nil - } - - return nil, marshalErrorf("cannot marshal %T into %s", value, info) -} - -func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case *net.IP: - if x := len(data); !(x == 4 || x == 16) { - return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) - } - buf := copyBytes(data) - ip := net.IP(buf) - if v4 := ip.To4(); v4 != nil { - *v = v4 - return nil - } - *v = ip - return nil - case *string: - if len(data) == 0 { - *v = "" - return nil - } - ip := net.IP(data) - if v4 := ip.To4(); v4 != nil { - *v = v4.String() - return nil - } - *v = ip.String() - return nil - } - return unmarshalErrorf("cannot unmarshal %s into %T", info, value) -} - -func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { - tuple := info.(TupleTypeInfo) - switch v := value.(type) { - case unsetColumn: - return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples") - case []interface{}: - if len(v) != len(tuple.Elems) { - return nil, unmarshalErrorf("cannont marshal tuple: wrong number of elements") - } - - var buf []byte - for i, elem := range v { - if elem == nil { - buf = appendInt(buf, int32(-1)) - continue - } - - data, err := Marshal(tuple.Elems[i], elem) - if err != nil { - return nil, err - } - - n := len(data) - buf = appendInt(buf, int32(n)) - buf = append(buf, data...) - } - - return buf, nil - } - - rv := reflect.ValueOf(value) - t := rv.Type() - k := t.Kind() - - switch k { - case reflect.Struct: - if v := t.NumField(); v != len(tuple.Elems) { - return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) - } - - var buf []byte - for i, elem := range tuple.Elems { - field := rv.Field(i) - - if field.Kind() == reflect.Ptr && field.IsNil() { - buf = appendInt(buf, int32(-1)) - continue - } - - data, err := Marshal(elem, field.Interface()) - if err != nil { - return nil, err - } - - n := len(data) - buf = appendInt(buf, int32(n)) - buf = append(buf, data...) - } - - return buf, nil - case reflect.Slice, reflect.Array: - size := rv.Len() - if size != len(tuple.Elems) { - return nil, marshalErrorf("can not marshal tuple into %v of length %d need %d elements", k, size, len(tuple.Elems)) - } - - var buf []byte - for i, elem := range tuple.Elems { - item := rv.Index(i) - - if item.Kind() == reflect.Ptr && item.IsNil() { - buf = appendInt(buf, int32(-1)) - continue - } - - data, err := Marshal(elem, item.Interface()) - if err != nil { - return nil, err - } - - n := len(data) - buf = appendInt(buf, int32(n)) - buf = append(buf, data...) - } - - return buf, nil - } - - return nil, marshalErrorf("cannot marshal %T into %s", value, tuple) -} - -func readBytes(p []byte) ([]byte, []byte) { - // TODO: really should use a framer - size := readInt(p) - p = p[4:] - if size < 0 { - return nil, p - } - return p[:size], p[size:] -} - -// currently only support unmarshal into a list of values, this makes it possible -// to support tuples without changing the query API. In the future this can be extend -// to allow unmarshalling into custom tuple types. -func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { - if v, ok := value.(Unmarshaler); ok { - return v.UnmarshalCQL(info, data) - } - - tuple := info.(TupleTypeInfo) - switch v := value.(type) { - case []interface{}: - for i, elem := range tuple.Elems { - // each element inside data is a [bytes] - var p []byte - if len(data) >= 4 { - p, data = readBytes(data) - } - err := Unmarshal(elem, p, v[i]) - if err != nil { - return err - } - } - - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - - rv = rv.Elem() - t := rv.Type() - k := t.Kind() - - switch k { - case reflect.Struct: - if v := t.NumField(); v != len(tuple.Elems) { - return unmarshalErrorf("can not unmarshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) - } - - for i, elem := range tuple.Elems { - var p []byte - if len(data) >= 4 { - p, data = readBytes(data) - } - - v, err := elem.NewWithError() - if err != nil { - return err - } - if err := Unmarshal(elem, p, v); err != nil { - return err - } - - switch rv.Field(i).Kind() { - case reflect.Ptr: - if p != nil { - rv.Field(i).Set(reflect.ValueOf(v)) - } else { - rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v))) - } - default: - rv.Field(i).Set(reflect.ValueOf(v).Elem()) - } - } - - return nil - case reflect.Slice, reflect.Array: - if k == reflect.Array { - size := rv.Len() - if size != len(tuple.Elems) { - return unmarshalErrorf("can not unmarshal tuple into array of length %d need %d elements", size, len(tuple.Elems)) - } - } else { - rv.Set(reflect.MakeSlice(t, len(tuple.Elems), len(tuple.Elems))) - } - - for i, elem := range tuple.Elems { - var p []byte - if len(data) >= 4 { - p, data = readBytes(data) - } - - v, err := elem.NewWithError() - if err != nil { - return err - } - if err := Unmarshal(elem, p, v); err != nil { - return err - } - - switch rv.Index(i).Kind() { - case reflect.Ptr: - if p != nil { - rv.Index(i).Set(reflect.ValueOf(v)) - } else { - rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v))) - } - default: - rv.Index(i).Set(reflect.ValueOf(v).Elem()) - } - } - - return nil - } - - return unmarshalErrorf("cannot unmarshal %s into %T", info, value) -} - -// UDTMarshaler is an interface which should be implemented by users wishing to -// handle encoding UDT types to sent to Cassandra. Note: due to current implentations -// methods defined for this interface must be value receivers not pointer receivers. -type UDTMarshaler interface { - // MarshalUDT will be called for each field in the the UDT returned by Cassandra, - // the implementor should marshal the type to return by for example calling - // Marshal. - MarshalUDT(name string, info TypeInfo) ([]byte, error) -} - -// UDTUnmarshaler should be implemented by users wanting to implement custom -// UDT unmarshaling. -type UDTUnmarshaler interface { - // UnmarshalUDT will be called for each field in the UDT return by Cassandra, - // the implementor should unmarshal the data into the value of their chosing, - // for example by calling Unmarshal. - UnmarshalUDT(name string, info TypeInfo, data []byte) error -} - -func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { - udt := info.(UDTTypeInfo) - - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types") - case UDTMarshaler: - var buf []byte - for _, e := range udt.Elements { - data, err := v.MarshalUDT(e.Name, e.Type) - if err != nil { - return nil, err - } - - buf = appendBytes(buf, data) - } - - return buf, nil - case map[string]interface{}: - var buf []byte - for _, e := range udt.Elements { - val, ok := v[e.Name] - - var data []byte - - if ok { - var err error - data, err = Marshal(e.Type, val) - if err != nil { - return nil, err - } - } - - buf = appendBytes(buf, data) - } - - return buf, nil - } - - k := reflect.ValueOf(value) - if k.Kind() == reflect.Ptr { - if k.IsNil() { - return nil, marshalErrorf("cannot marshal %T into %s", value, info) - } - k = k.Elem() - } - - if k.Kind() != reflect.Struct || !k.IsValid() { - return nil, marshalErrorf("cannot marshal %T into %s", value, info) - } - - fields := make(map[string]reflect.Value) - t := reflect.TypeOf(value) - for i := 0; i < t.NumField(); i++ { - sf := t.Field(i) - - if tag := sf.Tag.Get("cql"); tag != "" { - fields[tag] = k.Field(i) - } - } - - var buf []byte - for _, e := range udt.Elements { - f, ok := fields[e.Name] - if !ok { - f = k.FieldByName(e.Name) - } - - var data []byte - if f.IsValid() && f.CanInterface() { - var err error - data, err = Marshal(e.Type, f.Interface()) - if err != nil { - return nil, err - } - } - - buf = appendBytes(buf, data) - } - - return buf, nil -} - -func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { - switch v := value.(type) { - case Unmarshaler: - return v.UnmarshalCQL(info, data) - case UDTUnmarshaler: - udt := info.(UDTTypeInfo) - - for id, e := range udt.Elements { - if len(data) == 0 { - return nil - } - if len(data) < 4 { - return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) - } - - var p []byte - p, data = readBytes(data) - if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil { - return err - } - } - - return nil - case *map[string]interface{}: - udt := info.(UDTTypeInfo) - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - - rv = rv.Elem() - t := rv.Type() - if t.Kind() != reflect.Map { - return unmarshalErrorf("can not unmarshal %s into %T", info, value) - } else if data == nil { - rv.Set(reflect.Zero(t)) - return nil - } - - rv.Set(reflect.MakeMap(t)) - m := *v - - for id, e := range udt.Elements { - if len(data) == 0 { - return nil - } - if len(data) < 4 { - return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) - } - - valType, err := goType(e.Type) - if err != nil { - return unmarshalErrorf("can not unmarshal %s: %v", info, err) - } - - val := reflect.New(valType) - - var p []byte - p, data = readBytes(data) - - if err := Unmarshal(e.Type, p, val.Interface()); err != nil { - return err - } - - m[e.Name] = val.Elem().Interface() - } - - return nil - } - - rv := reflect.ValueOf(value) - if rv.Kind() != reflect.Ptr { - return unmarshalErrorf("can not unmarshal into non-pointer %T", value) - } - k := rv.Elem() - if k.Kind() != reflect.Struct || !k.IsValid() { - return unmarshalErrorf("cannot unmarshal %s into %T", info, value) - } - - if len(data) == 0 { - if k.CanSet() { - k.Set(reflect.Zero(k.Type())) - } - - return nil - } - - t := k.Type() - fields := make(map[string]reflect.Value, t.NumField()) - for i := 0; i < t.NumField(); i++ { - sf := t.Field(i) - - if tag := sf.Tag.Get("cql"); tag != "" { - fields[tag] = k.Field(i) - } - } - - udt := info.(UDTTypeInfo) - for id, e := range udt.Elements { - if len(data) == 0 { - return nil - } - if len(data) < 4 { - // UDT def does not match the column value - return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) - } - - var p []byte - p, data = readBytes(data) - - f, ok := fields[e.Name] - if !ok { - f = k.FieldByName(e.Name) - if f == emptyValue { - // skip fields which exist in the UDT but not in - // the struct passed in - continue - } - } - - if !f.IsValid() || !f.CanAddr() { - return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name) - } - - fk := f.Addr().Interface() - if err := Unmarshal(e.Type, p, fk); err != nil { - return err - } - } - - return nil -} - -// TypeInfo describes a Cassandra specific data type. -type TypeInfo interface { - Type() Type - Version() byte - Custom() string - - // New creates a pointer to an empty version of whatever type - // is referenced by the TypeInfo receiver. - // - // If there is no corresponding Go type for the CQL type, New panics. - // - // Deprecated: Use NewWithError instead. - New() interface{} - - // NewWithError creates a pointer to an empty version of whatever type - // is referenced by the TypeInfo receiver. - // - // If there is no corresponding Go type for the CQL type, NewWithError returns an error. - NewWithError() (interface{}, error) -} - -type NativeType struct { - proto byte - typ Type - custom string // only used for TypeCustom -} - -func NewNativeType(proto byte, typ Type, custom string) NativeType { - return NativeType{proto, typ, custom} -} - -func (t NativeType) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (t NativeType) New() interface{} { - val, err := t.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -func (s NativeType) Type() Type { - return s.typ -} - -func (s NativeType) Version() byte { - return s.proto -} - -func (s NativeType) Custom() string { - return s.custom -} - -func (s NativeType) String() string { - switch s.typ { - case TypeCustom: - return fmt.Sprintf("%s(%s)", s.typ, s.custom) - default: - return s.typ.String() - } -} - -type CollectionType struct { - NativeType - Key TypeInfo // only used for TypeMap - Elem TypeInfo // only used for TypeMap, TypeList and TypeSet -} - -func (t CollectionType) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (t CollectionType) New() interface{} { - val, err := t.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -func (c CollectionType) String() string { - switch c.typ { - case TypeMap: - return fmt.Sprintf("%s(%s, %s)", c.typ, c.Key, c.Elem) - case TypeList, TypeSet: - return fmt.Sprintf("%s(%s)", c.typ, c.Elem) - case TypeCustom: - return fmt.Sprintf("%s(%s)", c.typ, c.custom) - default: - return c.typ.String() - } -} - -type TupleTypeInfo struct { - NativeType - Elems []TypeInfo -} - -func (t TupleTypeInfo) String() string { - var buf bytes.Buffer - buf.WriteString(fmt.Sprintf("%s(", t.typ)) - for _, elem := range t.Elems { - buf.WriteString(fmt.Sprintf("%s, ", elem)) - } - buf.Truncate(buf.Len() - 2) - buf.WriteByte(')') - return buf.String() -} - -func (t TupleTypeInfo) NewWithError() (interface{}, error) { - typ, err := goType(t) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (t TupleTypeInfo) New() interface{} { - val, err := t.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -type UDTField struct { - Name string - Type TypeInfo -} - -type UDTTypeInfo struct { - NativeType - KeySpace string - Name string - Elements []UDTField -} - -func (u UDTTypeInfo) NewWithError() (interface{}, error) { - typ, err := goType(u) - if err != nil { - return nil, err - } - return reflect.New(typ).Interface(), nil -} - -func (u UDTTypeInfo) New() interface{} { - val, err := u.NewWithError() - if err != nil { - panic(err.Error()) - } - return val -} - -func (u UDTTypeInfo) String() string { - buf := &bytes.Buffer{} - - fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) - first := true - for _, e := range u.Elements { - if !first { - fmt.Fprint(buf, ",") - } else { - first = false - } - - fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) - } - fmt.Fprint(buf, "}") - - return buf.String() -} - -// String returns a human readable name for the Cassandra datatype -// described by t. -// Type is the identifier of a Cassandra internal datatype. -type Type int +type UDTTypeInfo = protocol.UDTTypeInfo +// // String returns a human readable name for the Cassandra datatype +// // described by t. +// // Type is the identifier of a Cassandra internal datatype. +// type Type int const ( - TypeCustom Type = 0x0000 - TypeAscii Type = 0x0001 - TypeBigInt Type = 0x0002 - TypeBlob Type = 0x0003 - TypeBoolean Type = 0x0004 - TypeCounter Type = 0x0005 - TypeDecimal Type = 0x0006 - TypeDouble Type = 0x0007 - TypeFloat Type = 0x0008 - TypeInt Type = 0x0009 - TypeText Type = 0x000A - TypeTimestamp Type = 0x000B - TypeUUID Type = 0x000C - TypeVarchar Type = 0x000D - TypeVarint Type = 0x000E - TypeTimeUUID Type = 0x000F - TypeInet Type = 0x0010 - TypeDate Type = 0x0011 - TypeTime Type = 0x0012 - TypeSmallInt Type = 0x0013 - TypeTinyInt Type = 0x0014 - TypeDuration Type = 0x0015 - TypeList Type = 0x0020 - TypeMap Type = 0x0021 - TypeSet Type = 0x0022 - TypeUDT Type = 0x0030 - TypeTuple Type = 0x0031 + TypeCustom = protocol.TypeCustom + TypeAscii = protocol.TypeAscii + TypeBigInt = protocol.TypeBigInt + TypeBlob = protocol.TypeBlob + TypeBoolean = protocol.TypeBoolean + TypeCounter = protocol.TypeCounter + TypeDecimal = protocol.TypeDecimal + TypeDouble = protocol.TypeDouble + TypeFloat = protocol.TypeFloat + TypeInt = protocol.TypeInt + TypeText = protocol.TypeText + TypeTimestamp = protocol.TypeTimestamp + TypeUUID = protocol.TypeUUID + TypeVarchar = protocol.TypeVarchar + TypeVarint = protocol.TypeVarint + TypeTimeUUID = protocol.TypeTimeUUID + TypeInet = protocol.TypeInet + TypeDate = protocol.TypeDate + TypeTime = protocol.TypeTime + TypeSmallInt = protocol.TypeSmallInt + TypeTinyInt = protocol.TypeTinyInt + TypeDuration = protocol.TypeDuration + TypeList = protocol.TypeList + TypeMap = protocol.TypeMap + TypeSet = protocol.TypeSet + TypeUDT = protocol.TypeUDT + TypeTuple = protocol.TypeTuple ) -// String returns the name of the identifier. -func (t Type) String() string { - switch t { - case TypeCustom: - return "custom" - case TypeAscii: - return "ascii" - case TypeBigInt: - return "bigint" - case TypeBlob: - return "blob" - case TypeBoolean: - return "boolean" - case TypeCounter: - return "counter" - case TypeDecimal: - return "decimal" - case TypeDouble: - return "double" - case TypeFloat: - return "float" - case TypeInt: - return "int" - case TypeText: - return "text" - case TypeTimestamp: - return "timestamp" - case TypeUUID: - return "uuid" - case TypeVarchar: - return "varchar" - case TypeTimeUUID: - return "timeuuid" - case TypeInet: - return "inet" - case TypeDate: - return "date" - case TypeDuration: - return "duration" - case TypeTime: - return "time" - case TypeSmallInt: - return "smallint" - case TypeTinyInt: - return "tinyint" - case TypeList: - return "list" - case TypeMap: - return "map" - case TypeSet: - return "set" - case TypeVarint: - return "varint" - case TypeTuple: - return "tuple" - default: - return fmt.Sprintf("unknown_type_%d", t) - } -} - -type MarshalError string - -func (m MarshalError) Error() string { - return string(m) -} - -func marshalErrorf(format string, args ...interface{}) MarshalError { - return MarshalError(fmt.Sprintf(format, args...)) -} - -type UnmarshalError string - -func (m UnmarshalError) Error() string { - return string(m) -} +type MarshalError = protocol.MarshalError -func unmarshalErrorf(format string, args ...interface{}) UnmarshalError { - return UnmarshalError(fmt.Sprintf(format, args...)) -} +type UnmarshalError = protocol.UnmarshalError diff --git a/marshal_test.go b/marshal_test.go index 6c139e6bc..f96b52f55 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1692,7 +1692,7 @@ var typeLookupTest = []struct { } func testType(t *testing.T, cassType string, expectedType Type) { - if computedType := getApacheCassandraType(apacheCassandraTypePrefix + cassType); computedType != expectedType { + if computedType := protocol.GetApacheCassandraType(protocol.ApacheCassandraTypePrefix + cassType); computedType != expectedType { t.Errorf("Cassandra custom type lookup for %s failed. Expected %s, got %s.", cassType, expectedType.String(), computedType.String()) } } @@ -2207,7 +2207,7 @@ func TestMarshalNil(t *testing.T) { func TestUnmarshalInetCopyBytes(t *testing.T) { data := []byte{127, 0, 0, 1} var ip net.IP - if err := unmarshalInet(NativeType{proto: 2, typ: TypeInet}, data, &ip); err != nil { + if err := UnmarshalInet(NativeType{proto: 2, typ: TypeInet}, data, &ip); err != nil { t.Fatal(err) } @@ -2221,7 +2221,7 @@ func TestUnmarshalInetCopyBytes(t *testing.T) { func TestUnmarshalDate(t *testing.T) { data := []uint8{0x80, 0x0, 0x43, 0x31} var date time.Time - if err := unmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &date); err != nil { + if err := UnmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &date); err != nil { t.Fatal(err) } @@ -2232,7 +2232,7 @@ func TestUnmarshalDate(t *testing.T) { return } var stringDate string - if err2 := unmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &stringDate); err2 != nil { + if err2 := UnmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &stringDate); err2 != nil { t.Fatal(err2) } if expectedDate != stringDate { @@ -2345,7 +2345,7 @@ func BenchmarkUnmarshalVarchar(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if err := unmarshalVarchar(NativeType{}, src, &dst); err != nil { + if err := UnmarshalVarchar(NativeType{}, src, &dst); err != nil { b.Fatal(err) } } @@ -2489,7 +2489,7 @@ func BenchmarkUnmarshalUUID(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if err := unmarshalUUID(ti, src, &dst); err != nil { + if err := UnmarshalUUID(ti, src, &dst); err != nil { b.Fatal(err) } } diff --git a/metadata.go b/metadata.go index 6eb798f8a..ec7ab59d8 100644 --- a/metadata.go +++ b/metadata.go @@ -32,6 +32,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/gocql/gocql/internal/protocol" "strconv" "strings" "sync" @@ -196,7 +197,7 @@ func (c ColumnKind) String() string { func (c *ColumnKind) UnmarshalCQL(typ TypeInfo, p []byte) error { if typ.Type() != TypeVarchar { - return unmarshalErrorf("unable to marshall %s into ColumnKind, expected Varchar", typ) + return protocol.UnmarshalErrorf("unable to marshall %s into ColumnKind, expected Varchar", typ) } kind, err := columnKindFromSchema(string(p)) @@ -383,7 +384,7 @@ func compileMetadata( col := &columns[i] // decode the validator for TypeInfo and order if col.ClusteringOrder != "" { // Cassandra 3.x+ - col.Type = getCassandraType(col.Validator, logger) + col.Type = protocol.GetCassandraType(col.Validator, logger) col.Order = ASC if col.ClusteringOrder == "desc" { col.Order = DESC @@ -948,10 +949,10 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, } func getTypeInfo(t string, logger StdLogger) TypeInfo { - if strings.HasPrefix(t, apacheCassandraTypePrefix) { - t = apacheToCassandraType(t) + if strings.HasPrefix(t, protocol.ApacheCassandraTypePrefix) { + t = protocol.ApacheToCassandraType(t) } - return getCassandraType(t, logger) + return protocol.GetCassandraType(t, logger) } func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) { @@ -1234,8 +1235,8 @@ func (t *typeParser) parse() typeParserResult { isComposite: false, types: []TypeInfo{ NativeType{ - typ: TypeCustom, - custom: t.input, + Typ: TypeCustom, + Cust: t.input, }, }, reversed: []bool{false}, @@ -1311,18 +1312,18 @@ func (t *typeParser) parse() typeParserResult { func (class *typeParserClassNode) asTypeInfo() TypeInfo { if strings.HasPrefix(class.name, LIST_TYPE) { elem := class.params[0].class.asTypeInfo() - return CollectionType{ + return protocol.CollectionType{ NativeType: NativeType{ - typ: TypeList, + Typ: TypeList, }, Elem: elem, } } if strings.HasPrefix(class.name, SET_TYPE) { elem := class.params[0].class.asTypeInfo() - return CollectionType{ + return protocol.CollectionType{ NativeType: NativeType{ - typ: TypeSet, + Typ: TypeSet, }, Elem: elem, } @@ -1330,9 +1331,9 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { if strings.HasPrefix(class.name, MAP_TYPE) { key := class.params[0].class.asTypeInfo() elem := class.params[1].class.asTypeInfo() - return CollectionType{ + return protocol.CollectionType{ NativeType: NativeType{ - typ: TypeMap, + Typ: TypeMap, }, Key: key, Elem: elem, @@ -1340,10 +1341,10 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { } // must be a simple type or custom type - info := NativeType{typ: getApacheCassandraType(class.name)} - if info.typ == TypeCustom { + info := NativeType{Typ: protocol.GetApacheCassandraType(class.name)} + if info.Typ == TypeCustom { // add the entire class definition - info.custom = class.input + info.Cust = class.input } return info } diff --git a/query_executor.go b/query_executor.go index d6be02e53..c96f6d919 100644 --- a/query_executor.go +++ b/query_executor.go @@ -156,7 +156,7 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne // Update host switch iter.err { case context.Canceled, context.DeadlineExceeded, ErrNotFound: - // those errors represents logical errors, they should not count + // those internal_errors represents logical internal_errors, they should not count // toward removing a node from the pool selectedHost.Mark(nil) return iter diff --git a/session.go b/session.go index b884735c2..706d74e17 100644 --- a/session.go +++ b/session.go @@ -30,6 +30,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/gocql/gocql/internal/protocol" "io" "net" "strings" @@ -645,23 +646,23 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI // TODO: it would be nice to mark hosts here but as we are not using the policies // to fetch hosts we cant - if info.request.colCount == 0 { + if info.request.ColCount == 0 { // no arguments, no routing key, and no error return nil, nil } - table := info.request.table - keyspace := info.request.keyspace + table := info.request.Table + keyspace := info.request.Keyspace - if len(info.request.pkeyColumns) > 0 { + if len(info.request.PkeyColumns) > 0 { // proto v4 dont need to calculate primary key columns - types := make([]TypeInfo, len(info.request.pkeyColumns)) - for i, col := range info.request.pkeyColumns { - types[i] = info.request.columns[col].TypeInfo + types := make([]TypeInfo, len(info.request.PkeyColumns)) + for i, col := range info.request.PkeyColumns { + types[i] = info.request.Columns[col].TypeInfo } routingKeyInfo := &routingKeyInfo{ - indexes: info.request.pkeyColumns, + indexes: info.request.PkeyColumns, types: types, keyspace: keyspace, table: table, @@ -672,7 +673,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI } var keyspaceMetadata *KeyspaceMetadata - keyspaceMetadata, inflight.err = s.KeyspaceMetadata(info.request.columns[0].Keyspace) + keyspaceMetadata, inflight.err = s.KeyspaceMetadata(info.request.Columns[0].Keyspace) if inflight.err != nil { // don't cache this error s.routingKeyInfoCache.Remove(stmt) @@ -705,7 +706,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI routingKeyInfo.indexes[keyIndex] = -1 // find the column in the query info - for argIndex, boundColumn := range info.request.columns { + for argIndex, boundColumn := range info.request.Columns { if keyColumn.Name == boundColumn.Name { // there may be many such bound columns, pick the first routingKeyInfo.indexes[keyIndex] = argIndex @@ -1429,12 +1430,12 @@ func (q *Query) releaseAfterExecution() { type Iter struct { err error pos int - meta resultMetadata + meta protocol.ResultMetadata numRows int next *nextIter host *HostInfo - framer *framer + framer *protocol.Framer closed int32 } @@ -1445,7 +1446,7 @@ func (iter *Iter) Host() *HostInfo { // Columns returns the name and type of the selected columns. func (iter *Iter) Columns() []ColumnInfo { - return iter.meta.columns + return iter.meta.Columns } type Scanner interface { @@ -1533,15 +1534,15 @@ func (is *iterScanner) Scan(dest ...interface{}) error { iter := is.iter // currently only support scanning into an expand tuple, such that its the same // as scanning in more values from a single column - if len(dest) != iter.meta.actualColCount { - return fmt.Errorf("gocql: not enough columns to scan into: have %d want %d", len(dest), iter.meta.actualColCount) + if len(dest) != iter.meta.ActualColCount { + return fmt.Errorf("gocql: not enough columns to scan into: have %d want %d", len(dest), iter.meta.ActualColCount) } // i is the current position in dest, could posible replace it and just use // slices of dest i := 0 var err error - for _, col := range iter.meta.columns { + for _, col := range iter.meta.Columns { var n int n, err = scanColumn(is.cols[i], col, dest[i:]) if err != nil { @@ -1569,11 +1570,11 @@ func (iter *Iter) Scanner() Scanner { return nil } - return &iterScanner{iter: iter, cols: make([][]byte, len(iter.meta.columns))} + return &iterScanner{iter: iter, cols: make([][]byte, len(iter.meta.Columns))} } func (iter *Iter) readColumn() ([]byte, error) { - return iter.framer.readBytesInternal() + return iter.framer.ReadBytesInternal() } // Scan consumes the next row of the iterator and copies the columns of the @@ -1583,7 +1584,7 @@ func (iter *Iter) readColumn() ([]byte, error) { // // Scan returns true if the row was successfully unmarshaled or false if the // end of the result set was reached or if an error occurred. Close should -// be called afterwards to retrieve any potential errors. +// be called afterwards to retrieve any potential internal_errors. func (iter *Iter) Scan(dest ...interface{}) bool { if iter.err != nil { return false @@ -1603,15 +1604,15 @@ func (iter *Iter) Scan(dest ...interface{}) bool { // currently only support scanning into an expand tuple, such that its the same // as scanning in more values from a single column - if len(dest) != iter.meta.actualColCount { - iter.err = fmt.Errorf("gocql: not enough columns to scan into: have %d want %d", len(dest), iter.meta.actualColCount) + if len(dest) != iter.meta.ActualColCount { + iter.err = fmt.Errorf("gocql: not enough columns to scan into: have %d want %d", len(dest), iter.meta.ActualColCount) return false } // i is the current position in dest, could posible replace it and just use // slices of dest i := 0 - for _, col := range iter.meta.columns { + for _, col := range iter.meta.Columns { colBytes, err := iter.readColumn() if err != nil { iter.err = err @@ -1639,7 +1640,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool { // See https://datastax.github.io/java-driver/manual/custom_payloads/ func (iter *Iter) GetCustomPayload() map[string][]byte { if iter.framer != nil { - return iter.framer.customPayload + return iter.framer.CustomPayload } return nil } @@ -1649,12 +1650,12 @@ func (iter *Iter) GetCustomPayload() map[string][]byte { // This is only available starting with CQL Protocol v4. func (iter *Iter) Warnings() []string { if iter.framer != nil { - return iter.framer.header.warnings + return iter.framer.Header.Warnings } return nil } -// Close closes the iterator and returns any errors that happened during +// Close closes the iterator and returns any internal_errors that happened during // the query or the iteration. func (iter *Iter) Close() error { if atomic.CompareAndSwapInt32(&iter.closed, 0, 1) { @@ -1685,7 +1686,7 @@ func (iter *Iter) checkErrAndNotFound() error { // PageState return the current paging state for a query which can be used for // subsequent queries to resume paging this point. func (iter *Iter) PageState() []byte { - return iter.meta.pagingState + return iter.meta.PagingState } // NumRows returns the number of rows in this pagination, it will update when new @@ -2030,7 +2031,7 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } -type BatchType byte +type BatchType = protocol.BatchType const ( LoggedBatch BatchType = 0 @@ -2045,16 +2046,7 @@ type BatchEntry struct { binding func(q *QueryInfo) ([]interface{}, error) } -type ColumnInfo struct { - Keyspace string - Table string - Name string - TypeInfo TypeInfo -} - -func (c ColumnInfo) String() string { - return fmt.Sprintf("[column keyspace=%s table=%s name=%s type=%v]", c.Keyspace, c.Table, c.Name, c.TypeInfo) -} +type ColumnInfo = protocol.ColumnInfo // routing key indexes LRU cache type routingKeyInfoLRU struct { @@ -2185,7 +2177,7 @@ type ObservedQuery struct { Metrics *hostMetrics // Err is the error in the query. - // It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error + // It only tracks network internal_errors or internal_errors of bad cassandra syntax, in particular selects with no match return nil error Err error // Attempt is the index of attempt at executing this query. @@ -2199,7 +2191,7 @@ type ObservedQuery struct { type QueryObserver interface { // ObserveQuery gets called on every query to cassandra, including all queries in an iterator when paging is enabled. // It doesn't get called if there is no query because the session is closed or there are no connections available. - // The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil. + // The error reported only shows query internal_errors, i.e. if a SELECT is valid but finds no matches it will be nil. ObserveQuery(context.Context, ObservedQuery) } @@ -2219,7 +2211,7 @@ type ObservedBatch struct { Host *HostInfo // Err is the error in the batch query. - // It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error + // It only tracks network internal_errors or internal_errors of bad cassandra syntax, in particular selects with no match return nil error Err error // The metrics per this host @@ -2235,7 +2227,7 @@ type BatchObserver interface { // ObserveBatch gets called on every batch query to cassandra. // It also gets called once for each query in a batch. // It doesn't get called if there is no query because the session is closed or there are no connections available. - // The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil. + // The error reported only shows query internal_errors, i.e. if a SELECT is valid but finds no matches it will be nil. // Unlike QueryObserver.ObserveQuery it does no reporting on rows read. ObserveBatch(context.Context, ObservedBatch) } @@ -2279,12 +2271,6 @@ var ( ErrNoMetadata = errors.New("no metadata available") ) -type ErrProtocol struct{ error } - -func NewErrProtocol(format string, args ...interface{}) error { - return ErrProtocol{fmt.Errorf(format, args...)} -} - // BatchSizeMaximum is the maximum number of statements a batch operation can have. // This limit is set by cassandra and could change in the future. const BatchSizeMaximum = 65535 diff --git a/uuid.go b/uuid.go index cc5f1c21f..7df3afe9a 100644 --- a/uuid.go +++ b/uuid.go @@ -31,27 +31,18 @@ package gocql import ( "crypto/rand" - "errors" - "fmt" + "github.com/gocql/gocql/internal/protocol" "io" "net" - "strings" "sync/atomic" "time" ) -type UUID [16]byte +type UUID = protocol.UUID var hardwareAddr []byte var clockSeq uint32 -const ( - VariantNCSCompat = 0 - VariantIETF = 2 - VariantMicrosoft = 6 - VariantFuture = 7 -) - func init() { if interfaces, err := net.Interfaces(); err == nil { for _, i := range interfaces { @@ -79,42 +70,9 @@ func init() { clockSeq = uint32(clockSeqRand[1])<<8 | uint32(clockSeqRand[0]) } -// ParseUUID parses a 32 digit hexadecimal number (that might contain hypens) -// representing an UUID. -func ParseUUID(input string) (UUID, error) { - var u UUID - j := 0 - for _, r := range input { - switch { - case r == '-' && j&1 == 0: - continue - case r >= '0' && r <= '9' && j < 32: - u[j/2] |= byte(r-'0') << uint(4-j&1*4) - case r >= 'a' && r <= 'f' && j < 32: - u[j/2] |= byte(r-'a'+10) << uint(4-j&1*4) - case r >= 'A' && r <= 'F' && j < 32: - u[j/2] |= byte(r-'A'+10) << uint(4-j&1*4) - default: - return UUID{}, fmt.Errorf("invalid UUID %q", input) - } - j += 1 - } - if j != 32 { - return UUID{}, fmt.Errorf("invalid UUID %q", input) - } - return u, nil -} - -// UUIDFromBytes converts a raw byte slice to an UUID. -func UUIDFromBytes(input []byte) (UUID, error) { - var u UUID - if len(input) != 16 { - return u, errors.New("UUIDs must be exactly 16 bytes long") - } +var ParseUUID = protocol.ParseUUID - copy(u[:], input) - return u, nil -} +var UUIDFromBytes = protocol.UUIDFromBytes func MustRandomUUID() UUID { uuid, err := RandomUUID() @@ -139,7 +97,7 @@ func RandomUUID() (UUID, error) { return u, nil } -var timeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix() +var timeBase = protocol.TimeBase // getTimestamp converts time to UUID (version 1) timestamp. // It must be an interval of 100-nanoseconds since timeBase. @@ -227,118 +185,3 @@ func TimeUUIDWith(t int64, clock uint32, node []byte) UUID { // String returns the UUID in it's canonical form, a 32 digit hexadecimal // number in the form of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. -func (u UUID) String() string { - var offsets = [...]int{0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34} - const hexString = "0123456789abcdef" - r := make([]byte, 36) - for i, b := range u { - r[offsets[i]] = hexString[b>>4] - r[offsets[i]+1] = hexString[b&0xF] - } - r[8] = '-' - r[13] = '-' - r[18] = '-' - r[23] = '-' - return string(r) - -} - -// Bytes returns the raw byte slice for this UUID. A UUID is always 128 bits -// (16 bytes) long. -func (u UUID) Bytes() []byte { - return u[:] -} - -// Variant returns the variant of this UUID. This package will only generate -// UUIDs in the IETF variant. -func (u UUID) Variant() int { - x := u[8] - if x&0x80 == 0 { - return VariantNCSCompat - } - if x&0x40 == 0 { - return VariantIETF - } - if x&0x20 == 0 { - return VariantMicrosoft - } - return VariantFuture -} - -// Version extracts the version of this UUID variant. The RFC 4122 describes -// five kinds of UUIDs. -func (u UUID) Version() int { - return int(u[6] & 0xF0 >> 4) -} - -// Node extracts the MAC address of the node who generated this UUID. It will -// return nil if the UUID is not a time based UUID (version 1). -func (u UUID) Node() []byte { - if u.Version() != 1 { - return nil - } - return u[10:] -} - -// Clock extracts the clock sequence of this UUID. It will return zero if the -// UUID is not a time based UUID (version 1). -func (u UUID) Clock() uint32 { - if u.Version() != 1 { - return 0 - } - - // Clock sequence is the lower 14bits of u[8:10] - return uint32(u[8]&0x3F)<<8 | uint32(u[9]) -} - -// Timestamp extracts the timestamp information from a time based UUID -// (version 1). -func (u UUID) Timestamp() int64 { - if u.Version() != 1 { - return 0 - } - return int64(uint64(u[0])<<24|uint64(u[1])<<16| - uint64(u[2])<<8|uint64(u[3])) + - int64(uint64(u[4])<<40|uint64(u[5])<<32) + - int64(uint64(u[6]&0x0F)<<56|uint64(u[7])<<48) -} - -// Time is like Timestamp, except that it returns a time.Time. -func (u UUID) Time() time.Time { - if u.Version() != 1 { - return time.Time{} - } - t := u.Timestamp() - sec := t / 1e7 - nsec := (t % 1e7) * 100 - return time.Unix(sec+timeBase, nsec).UTC() -} - -// Marshaling for JSON -func (u UUID) MarshalJSON() ([]byte, error) { - return []byte(`"` + u.String() + `"`), nil -} - -// Unmarshaling for JSON -func (u *UUID) UnmarshalJSON(data []byte) error { - str := strings.Trim(string(data), `"`) - if len(str) > 36 { - return fmt.Errorf("invalid JSON UUID %s", str) - } - - parsed, err := ParseUUID(str) - if err == nil { - copy(u[:], parsed[:]) - } - - return err -} - -func (u UUID) MarshalText() ([]byte, error) { - return []byte(u.String()), nil -} - -func (u *UUID) UnmarshalText(text []byte) (err error) { - *u, err = ParseUUID(string(text)) - return -}