diff --git a/frame.go b/frame.go index d374ae574..54ad6776f 100644 --- a/frame.go +++ b/frame.go @@ -192,6 +192,8 @@ const ( type Consistency uint16 +type SerialConsistency = Consistency + const ( Any Consistency = 0x00 One Consistency = 0x01 @@ -202,6 +204,8 @@ const ( LocalQuorum Consistency = 0x06 EachQuorum Consistency = 0x07 LocalOne Consistency = 0x0A + Serial Consistency = 0x08 + LocalSerial Consistency = 0x09 ) func (c Consistency) String() string { @@ -224,6 +228,10 @@ func (c Consistency) String() string { return "EACH_QUORUM" case LocalOne: return "LOCAL_ONE" + case Serial: + return "SERIAL" + case LocalSerial: + return "LOCAL_SERIAL" default: return fmt.Sprintf("UNKNOWN_CONS_0x%x", uint16(c)) } @@ -253,6 +261,10 @@ func (c *Consistency) UnmarshalText(text []byte) error { *c = EachQuorum case "LOCAL_ONE": *c = LocalOne + case "SERIAL": + *c = Serial + case "LOCAL_SERIAL": + *c = LocalSerial default: return fmt.Errorf("invalid consistency %q", string(text)) } @@ -260,6 +272,10 @@ func (c *Consistency) UnmarshalText(text []byte) error { return nil } +func (c Consistency) IsSerial() bool { + return c == Serial || c == LocalSerial + +} func ParseConsistency(s string) Consistency { var c Consistency if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { @@ -286,41 +302,6 @@ func MustParseConsistency(s string) (Consistency, error) { return c, nil } -type SerialConsistency uint16 - -const ( - Serial SerialConsistency = 0x08 - LocalSerial SerialConsistency = 0x09 -) - -func (s SerialConsistency) String() string { - switch s { - case Serial: - return "SERIAL" - case LocalSerial: - return "LOCAL_SERIAL" - default: - return fmt.Sprintf("UNKNOWN_SERIAL_CONS_0x%x", uint16(s)) - } -} - -func (s SerialConsistency) MarshalText() (text []byte, err error) { - return []byte(s.String()), nil -} - -func (s *SerialConsistency) UnmarshalText(text []byte) error { - switch string(text) { - case "SERIAL": - *s = Serial - case "LOCAL_SERIAL": - *s = LocalSerial - default: - return fmt.Errorf("invalid consistency %q", string(text)) - } - - return nil -} - const ( apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." ) diff --git a/session.go b/session.go index a600b95f3..370d9bb62 100644 --- a/session.go +++ b/session.go @@ -1265,6 +1265,9 @@ func (q *Query) Bind(v ...interface{}) *Query { // SERIAL. This option will be ignored for anything else that a // conditional update/insert. func (q *Query) SerialConsistency(cons SerialConsistency) *Query { + if !cons.IsSerial() { + panic("Serial consistency can only be SERIAL or LOCAL_SERIAL got " + cons.String()) + } q.serialCons = cons return q } @@ -1915,6 +1918,9 @@ func (b *Batch) Size() int { // // Only available for protocol 3 and above func (b *Batch) SerialConsistency(cons SerialConsistency) *Batch { + if !cons.IsSerial() { + panic("Serial consistency can only be SERIAL or LOCAL_SERIAL got " + cons.String()) + } b.serialCons = cons return b }