package client // import "github.com/firepear/petrel/client"
// Copyright (c) 2014-2025 Shawn Boyette <shawn@firepear.net>. All
// rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
// This file implements the Petrel client.
import (
"crypto/tls"
"fmt"
"net"
"time"
p "github.com/firepear/petrel"
)
// Client is a Petrel client instance.
type Client struct {
// Resp points at the response data of the underlying connection.
// It is refreshed by every call to Dispatch.
Resp *p.Resp
// conn is the underlying Petrel connection.
conn *p.Conn
// cc is the "conn closed" flag. Quit sets it, after which
// Dispatch refuses to do any further work.
cc bool
}
// Config holds values to be passed to the client constructor.
type Config struct {
// Addr is either an IPv4 or IPv6 address followed by the
// desired port number ("127.0.0.1:9090", "[::1]:9090").
Addr string
// Timeout is the number of milliseconds the client will wait
// on a network operation before timing out. Default is no
// timeout (zero).
Timeout int64
// TLS is the (optional) TLS configuration. If it is nil, the
// connection will be unencrypted.
TLS *tls.Config
// Xferlim is the maximum number of bytes in a single read
// from the network (functionally it limits request or
// response payload size). If a read exceeds this limit,
// the connection will be dropped. Use this to prevent memory
// exhaustion by arbitrarily long network reads. The default
// (0) is unlimited.
Xferlim uint32
// HMACKey is the secret key used to generate MACs for signing
// and verifying messages. Default (nil) means MACs will not be
// generated for messages sent, or expected for messages
// received.
HMACKey []byte
}
// New returns a new Client, configured and ready to use.
//
// It dials c.Addr over TCP (using TLS when c.TLS is non-nil), wraps
// the connection in a petrel.Conn, and then performs the mandatory
// PROTOCHECK handshake. If the handshake cannot be sent, or the
// server answers with a status above 200, the connection is closed
// and a nil Client is returned along with an error describing the
// failure.
func New(c *Config) (*Client, error) {
	var conn net.Conn
	var err error
	if c.TLS == nil {
		conn, err = net.Dial("tcp", c.Addr)
	} else {
		conn, err = tls.Dial("tcp", c.Addr, c.TLS)
	}
	if err != nil {
		return nil, err
	}

	pconn := &p.Conn{
		NC:      conn,
		Plim:    c.Xferlim,
		Hkey:    c.HMACKey,
		Timeout: time.Duration(c.Timeout) * time.Millisecond,
	}
	client := &Client{Resp: &pconn.Resp, conn: pconn}

	// every client must verify protocol compatibility before use
	if err = client.Dispatch("PROTOCHECK", p.Proto); err != nil {
		client.Quit()
		return nil, err
	}
	if client.Resp.Status <= 200 {
		return client, nil
	}

	// anything above 200 means the handshake failed; drop the
	// connection and report why
	client.Quit()
	switch client.Resp.Status {
	case 400:
		return nil, fmt.Errorf("[400] PROTOCHECK unsupported")
	case 497:
		// the server's version is expected in the first payload
		// byte, but a misbehaving peer may send nothing; indexing
		// an empty payload would panic
		if len(client.Resp.Payload) == 0 {
			return nil, fmt.Errorf("[497] %s client v%d; server version not reported",
				p.Stats[497].Txt, p.Proto[0])
		}
		return nil, fmt.Errorf("[497] %s client v%d; server v%d",
			p.Stats[497].Txt,
			p.Proto[0], client.Resp.Payload[0])
	}
	return nil, fmt.Errorf("status %d %s", client.Resp.Status,
		p.Stats[client.Resp.Status].Txt)
}
// Dispatch sends a request and places the response in Client.Resp.
//
// req is the request name and must not exceed 255 bytes, since its
// length is checked against that bound before sending. payload is the
// request body.
//
// An error is returned if the Client's connection has already been
// closed, if req is too long, if the request cannot be written, or if
// reading the response fails. After the read, if Resp.Status is a
// Petrel-defined code (<= 1024) whose level is "Error", the Client
// closes its network connection and must not be reused.
func (c *Client) Dispatch(req string, payload []byte) error {
	// if a previous error closed the conn, refuse to do anything
	if c.cc {
		return fmt.Errorf("%d network conn closed; please create a new Client",
			c.Resp.Status)
	}
	// check for cmd length
	if len(req) > 255 {
		return fmt.Errorf("invalid request: '%s' > 255 bytes", req)
	}

	// increment sequence
	c.conn.Seq++

	// send data; name the request in the message and wrap the
	// underlying error so callers can inspect it with errors.Is/As
	if err := p.ConnWrite(c.conn, []byte(req), payload); err != nil {
		return fmt.Errorf("failed to send request '%s': %w", req, err)
	}

	// read response
	err := p.ConnRead(c.conn)

	// if our response status is Error, close the connection and
	// flag ourselves as done
	if c.Resp.Status <= 1024 && p.Stats[c.Resp.Status].Lvl == "Error" {
		c.Quit()
	}
	return err
}
// Quit marks the Client as closed and shuts down its network
// connection. Once Quit has run, Dispatch will refuse further work.
func (c *Client) Quit() {
	// flag first, so the Client is marked done even if Close misbehaves
	c.cc = true
	// the Client is finished regardless of whether Close reports
	// an error, so the result is intentionally discarded
	_ = c.conn.NC.Close()
}
// Copyright (c) 2014-2025 Shawn Boyette <shawn@firepear.net>. All
// rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
package petrel
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net"
"time"
)
// Resp is a packaged response, received from a Conn
type Resp struct {
// Status is the numeric status code carried in the transmission
// header (or set locally when a network error occurs).
Status uint16
// Req is the request name carried in the transmission.
Req string
// Payload is the body of the transmission.
Payload []byte
}
// Conn is a network connection plus associated per-connection data.
type Conn struct {
// Id; formerly cn (connection number). ignored for clients
Id string
// Short Id
Sid string
// Message sequence counter. ConnRead overwrites it with the
// sequence id found in each received header.
Seq uint32
// transmission header buffer (11 bytes, allocated by ConnRead)
hb []byte
// pmac holds the base64-encoded HMAC-SHA256 read from the peer
pmac []byte
// net.Conn, like it says on the tin
NC net.Conn
// Response struct, filled in by ConnRead
Resp Resp
// Payload length limit; zero means unlimited
Plim uint32
// Network timeout; zero means no deadline is set
Timeout time.Duration
// HMAC key; nil disables message authentication
Hkey []byte
// Msg channel
Msgr chan *Msg
}
// ConnRead reads one transmission from c.NC and stores the result in
// c.Resp (Status, Req, Payload) and c.Seq.
//
// Wire layout, as decoded here: an 11-byte little-endian header
// (status uint16, sequence uint32, request length uint8, payload
// length uint32), then the request name, then the payload, then --
// only when c.Hkey is non-nil -- a 44-byte base64 HMAC-SHA256 of the
// payload.
//
// On failure c.Resp.Status is set and a non-nil error is returned:
// 198 when the peer closed the connection (the error is io.EOF), 498
// for other read failures, 402 when the payload exceeds c.Plim, and
// 502 when HMAC verification fails. When c.Timeout is positive, the
// read deadline is re-armed before the header read, before each
// payload chunk, and before the HMAC read.
func ConnRead(c *Conn) error {
	const (
		hdrLen = 11  // status(2) + seq(4) + request length(1) + payload length(4)
		macLen = 44  // base64 encoding of a 32-byte SHA-256 MAC
		chunk  = 128 // size of the payload read buffer
	)
	if len(c.hb) != hdrLen {
		c.hb = make([]byte, hdrLen)
	}

	// deadline re-arms the read deadline when a timeout is configured
	deadline := func() error {
		if c.Timeout <= 0 {
			return nil
		}
		if err := c.NC.SetReadDeadline(time.Now().Add(c.Timeout)); err != nil {
			c.Resp.Status = 498
			return err
		}
		return nil
	}

	// read the transmission header. io.ReadFull is used because a
	// single Read on a stream may legitimately return fewer bytes
	// than requested; that is not an error by itself.
	if err := deadline(); err != nil {
		return err
	}
	if _, err := io.ReadFull(c.NC, c.hb); err != nil {
		if err == io.EOF {
			c.Resp.Status = 198 // (probably) clean disconnect
			return err
		}
		c.Resp.Status = 498 // read err
		if err == io.ErrUnexpectedEOF {
			return fmt.Errorf("%s: short read on xmission header", Stats[498].Txt)
		}
		return fmt.Errorf("%s: no xmission header: %w", Stats[498].Txt, err)
	}

	// get data from header, beginning with status
	c.Resp.Status = binary.LittleEndian.Uint16(c.hb[0:2])
	// sequence id
	c.Seq = binary.LittleEndian.Uint32(c.hb[2:6])
	// request length
	rlen := uint8(c.hb[6])
	// payload length
	plen := binary.LittleEndian.Uint32(c.hb[7:])

	// read and decode the request. we do this before erroring if
	// plen is over limit, so that Req will be set properly in
	// logging and the reply
	req := make([]byte, rlen)
	if n, err := io.ReadFull(c.NC, req); err != nil {
		if err == io.EOF {
			c.Resp.Status = 198 // (probably) clean disconnect
			return err
		}
		c.Resp.Status = 498 // read err
		if err == io.ErrUnexpectedEOF {
			return fmt.Errorf("%s: short read on request; expected %d bytes got %d",
				Stats[498].Txt, rlen, n)
		}
		return fmt.Errorf("%s: couldn't read request: %w", Stats[498].Txt, err)
	}
	c.Resp.Req = string(req)

	// reject the request if plen exceeds xfer limit
	if c.Plim != 0 && plen > c.Plim {
		c.Resp.Status = 402 // declared payload over length limit
		return fmt.Errorf("%d > %d", plen, c.Plim)
	}

	// read the payload in chunks rather than allocating plen bytes
	// up front, so a bogus declared length cannot force a huge
	// allocation when no limit is configured
	buf := make([]byte, chunk)
	payload := []byte{}
	got := uint32(0)
	for got < plen {
		// never ask for more than what remains of this payload;
		// this avoids reading across a transmission boundary
		want := plen - got
		if want > chunk {
			want = chunk
		}
		if err := deadline(); err != nil {
			return err
		}
		n, err := c.NC.Read(buf[:want])
		if err != nil {
			if err == io.EOF {
				c.Resp.Status = 198
				return err
			}
			c.Resp.Status = 498 // read err
			return err
		}
		got += uint32(n)
		if c.Plim > 0 && got > c.Plim {
			c.Resp.Status = 402 // (actual) payload over length limit
			return fmt.Errorf("%d bytes", got)
		}
		// keep only the bytes this Read actually delivered;
		// appending the whole buffer would splice stale bytes into
		// the payload whenever a read comes up short
		payload = append(payload, buf[:n]...)
	}
	c.Resp.Payload = payload

	// finally, if we have a MAC, read and verify it
	if c.Hkey != nil {
		// make sure there is somewhere to read the MAC into
		if len(c.pmac) != macLen {
			c.pmac = make([]byte, macLen)
		}
		if err := deadline(); err != nil {
			return err
		}
		if n, err := io.ReadFull(c.NC, c.pmac); err != nil {
			if err == io.EOF {
				c.Resp.Status = 198 // (probably) clean disconnect
				return err
			}
			c.Resp.Status = 498 // read err
			if err == io.ErrUnexpectedEOF {
				return fmt.Errorf("bad read on HMAC: %db != 44", n)
			}
			return err
		}
		// the sender computes the MAC over the payload alone
		mac := hmac.New(sha256.New, c.Hkey)
		mac.Write(c.Resp.Payload)
		computedMAC := make([]byte, macLen)
		base64.StdEncoding.Encode(computedMAC, mac.Sum(nil))
		if !hmac.Equal(c.pmac, computedMAC) {
			c.Resp.Status = 502 // hmac failure
			return fmt.Errorf("%v", Stats[502])
		}
	}
	return nil
}
// ConnWrite marshals request and payload into a single transmission
// (using c's current status, sequence number, and HMAC key) and
// writes it to c.NC.
//
// When c.Timeout is positive, a write deadline of now+Timeout is set
// first. If setting the deadline or performing the write fails,
// c.Resp.Status is set to 499 (write error) and the error is returned.
func ConnWrite(c *Conn, request, payload []byte) error {
	if c.Timeout > 0 {
		// this is a write, so it is the write deadline that must
		// be armed; a read deadline has no effect on Write
		err := c.NC.SetWriteDeadline(time.Now().Add(c.Timeout))
		if err != nil {
			c.Resp.Status = 499
			return err
		}
	}
	_, err := c.NC.Write(marshalXmission(c, request, payload))
	if err != nil {
		// overloading response, but eh
		c.Resp.Status = 499 // write error
	}
	return err
}
// marshalXmission builds the wire form of a transmission: an 11-byte
// little-endian header, followed by the request, the payload, and --
// when c.Hkey is set -- a 44-byte base64 HMAC-SHA256 of the payload.
func marshalXmission(c *Conn, request, payload []byte) []byte {
	// header: status(2) seq(4) request length(1) payload length(4)
	var hdr [11]byte
	binary.LittleEndian.PutUint16(hdr[0:2], c.Resp.Status)
	binary.LittleEndian.PutUint32(hdr[2:6], c.Seq)
	hdr[6] = uint8(len(request))
	binary.LittleEndian.PutUint32(hdr[7:11], uint32(len(payload)))

	// body follows the header directly
	out := append(hdr[:], request...)
	out = append(out, payload...)

	// without a key there is no MAC trailer
	if c.Hkey == nil {
		return out
	}

	// sign the payload and append the encoded MAC
	signer := hmac.New(sha256.New, c.Hkey)
	signer.Write(payload)
	encoded := make([]byte, 44)
	base64.StdEncoding.Encode(encoded, signer.Sum(nil))
	return append(out, encoded...)
}
package server
// Copyright (c) 2014-2025 Shawn Boyette <shawn@firepear.net>. All
// rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
// Socket code for petrel
import (
"fmt"
"time"
p "github.com/firepear/petrel"
)
// sockAccept is spawned by server.commonNew. It monitors the server's
// listener socket and spawns connections for clients.
//
// It loops on s.l.Accept. Each accepted connection is wrapped in a
// petrel.Conn carrying the server's payload limit, HMAC key, and
// timeout, recorded in the connection list, and handed to a new
// connServer goroutine. When Accept fails, a Msg is sent on s.Msgr --
// code 199 if a quit signal is pending on s.q, code 599 otherwise --
// and the function returns.
func (s *Server) sockAccept() {
	defer s.w.Done()
	for {
		// we wait here until the listener accepts a
		// connection and spawns us a petrel.Conn -- or an
		// error occurs, like the listener socket closing
		id, sid := p.GenId()
		pc := &p.Conn{Id: id, Sid: sid, Msgr: s.Msgr}
		nc, err := s.l.Accept()
		if err != nil {
			select {
			case <-s.q:
				// if there's a message on this
				// channel, s.Quit() was invoked and
				// we should close up shop
				s.Msgr <- &p.Msg{Cid: pc.Sid, Seq: pc.Seq, Req: "NONE",
					Code: 199, Txt: "err is spurious", Err: err}
			default:
				// otherwise, we've had an actual
				// networking error
				s.Msgr <- &p.Msg{Cid: pc.Sid, Seq: pc.Seq, Req: pc.Resp.Req,
					Code: 599, Txt: "unknown err", Err: err}
			}
			return
		}

		// we made it here so we have a new connection. wrap
		// our net.Conn in a petrel.Conn for parity with the
		// common netcode then add other values
		pc.NC = nc
		pc.Plim = s.rl
		pc.Hkey = s.hk
		// s.t is already a time.Duration (New scales the configured
		// millisecond count), so it must not be multiplied by
		// time.Millisecond a second time
		pc.Timeout = time.Duration(s.t)

		// increment our waitgroup
		s.w.Add(1)
		// add to connlist
		s.cl.Store(id, pc)
		// and launch the goroutine which will actually
		// service the client
		go s.connServer(pc)
	}
}
// connServer dispatches commands from, and sends responses to, a
// client. It is launched, per-connection, from sockAccept().
//
// It announces the connection on c.Msgr (code 100), then loops:
// read a request, look up its handler in s.d, run it, write the
// response, and report the outcome on c.Msgr. The loop ends when a
// read fails, when a request arrives with a status above 399, or when
// a response cannot be written. On exit the connection is removed from
// the connection list and closed, and the server's waitgroup is
// decremented.
func (s *Server) connServer(c *p.Conn) {
	// queue up decrementing the waitlist, closing the network
	// connection, and removing the connlist entry
	defer s.w.Done()
	defer c.NC.Close()
	defer s.cl.Delete(c.Id)

	c.Msgr <- &p.Msg{Cid: c.Sid, Seq: c.Seq, Req: c.Resp.Req, Code: 100,
		Txt: fmt.Sprintf("srv:%s %s %s", s.sid, p.Stats[100].Txt,
			c.NC.RemoteAddr().String()),
		Err: nil}

	for {
		// read the request
		err := p.ConnRead(c)
		if err != nil || c.Resp.Status > 399 {
			txt := p.Stats[c.Resp.Status].Txt
			c.Msgr <- &p.Msg{Cid: c.Sid, Seq: c.Seq, Req: c.Resp.Req,
				Code: c.Resp.Status, Txt: txt, Err: err}
			// tell the client why we are hanging up. err may be
			// nil here (bad status without a read error), in
			// which case the status text is the explanation
			reason := txt
			if err != nil {
				reason = err.Error()
			}
			// don't care about err here because we're
			// gonna bail, and this may not work anyway
			_ = p.ConnWrite(c, []byte(c.Resp.Req), []byte(reason))
			break
		}

		// start every request with an empty response so a previous
		// request's payload can never leak into this reply
		var response []byte
		var herr error

		// lookup the handler for this request
		handler, ok := s.d[c.Resp.Req]
		if ok {
			// dispatch the request and get the response
			c.Resp.Status, response, herr = handler(c.Resp.Payload)
			if herr != nil {
				c.Resp.Status = 500
			}
		} else {
			// unknown handler
			c.Resp.Status = 400
		}

		// we always send a response
		werr := p.ConnWrite(c, []byte(c.Resp.Req), response)

		// report the write error if there was one; otherwise surface
		// the handler's error so it is not silently lost
		merr := werr
		if merr == nil {
			merr = herr
		}
		if c.Resp.Status > 1024 {
			c.Msgr <- &p.Msg{Cid: c.Sid, Seq: c.Seq, Req: c.Resp.Req,
				Code: c.Resp.Status, Txt: "app defined code", Err: merr}
		} else {
			c.Msgr <- &p.Msg{Cid: c.Sid, Seq: c.Seq, Req: c.Resp.Req,
				Code: c.Resp.Status, Txt: p.Stats[c.Resp.Status].Txt,
				Err: merr}
		}
		// only a failed write ends the conversation
		if werr != nil {
			break
		}
	}
}
package server
// Copyright (c) 2014-2025 Shawn Boyette <shawn@firepear.net>. All
// rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
import (
"crypto/tls"
"fmt"
"log/slog"
"net"
"os"
"sync"
"time"
p "github.com/firepear/petrel"
)
// Server is a Petrel server instance.
type Server struct {
// Msgr is the internal-facing channel which receives
// notifications from connections.
Msgr chan *p.Msg
// Shutdown is the external-facing channel which notifies
// applications that a Server instance is shutting down
Shutdown chan error
id string // server id
sid string // short id
q chan bool // quit signal channel
s string // socket name (the configured address)
l net.Listener // listener socket
log *slog.Logger // Logger instance
d map[string]Handler // dispatch table
cl *sync.Map // connection list
t time.Duration // network timeout
rl uint32 // request length limit
hk []byte // HMAC key
w *sync.WaitGroup // tracks sockAccept and connServer goroutines
logd map[string]func(string, ...any) // logging functions, keyed by level name
}
// Config holds values to be passed to server constructors.
type Config struct {
// Addr is the IP+port of the socket, e.g."127.0.0.1:9090"
// or "[::1]:9090".
Addr string
// TLS is a crypto/tls configuration struct. If it is present,
// then the server will be TLS-enabled.
TLS *tls.Config
// Logger is the logging instance which will be used to handle
// messages. The default is a slog.TextHandler that writes to
// stdout, with a logging level of Debug
Logger *slog.Logger
// Timeout is the number of milliseconds the Server will wait
// when performing network ops before timing out. Default
// (zero) is no timeout. Each connection to the server is
// handled in a separate goroutine, however, so one blocked
// connection does not affect any others (unless you run out of
// file descriptors for new conns).
Timeout int64
// Xferlim is the maximum number of bytes in a single read
// from the network. If a request exceeds this limit, the
// connection will be dropped. Use this to prevent memory
// exhaustion by arbitrarily long network reads. The default
// (0) is unlimited.
// NOTE(review): an earlier version of this comment said the
// message header counts toward the limit; the read code in the
// petrel package compares the limit against the payload length
// only -- confirm which is intended.
Xferlim uint32
// HMACKey is the secret key used to generate MACs for signing
// and verifying messages. Default (nil) means MACs will not
// be generated for messages sent, or expected for messages
// received. Enabling message authentication adds significant
// overhead for each message sent and received, so use this
// when security outweighs performance.
HMACKey []byte
// Buffer sets how many instances of Msg may be queued in
// Server.Msgr. Non-Fatal Msgs which arrive while the buffer
// is full are dropped on the floor to prevent the Server from
// blocking. Defaults to 64.
// NOTE(review): the sends to Msgr in sockAccept and connServer
// are plain blocking sends, not drop-on-full -- confirm that
// the dropping described above is implemented.
Buffer int
}
// Handler is the type which functions passed to Server.Register must
// match: taking a slice of bytes as an argument; and returning a
// uint16 (indicating status), a slice of bytes (the response), and an
// error.
//
// Petrel reserves the status range 1-1024 for internal
// use. Applications may use codes in this range, but the system will
// interpret them according to their defined meanings (e.g. it is
// standard to return '200' for success with no additional
// context). Applications are free to define the remaining codes, up
// to 65535, as they see fit; the server reports any code above 1024
// as an application-defined code.
type Handler func([]byte) (uint16, []byte, error)
// New returns a new Server, ready to have handlers added.
//
// It opens a TCP listener on c.Addr (a TLS listener when c.TLS is
// non-nil), fills in defaults on c (Buffer becomes 64 when zero;
// Logger becomes a Debug-level text logger on stdout when nil),
// registers the mandatory PROTOCHECK handler, and then starts the
// message-handling and accept goroutines. An error is returned if the
// address cannot be resolved, the listener cannot be created, or the
// PROTOCHECK handler cannot be registered.
func New(c *Config) (*Server, error) {
	var l net.Listener
	var err error

	// create our listener
	if c.TLS != nil {
		l, err = tls.Listen("tcp", c.Addr, c.TLS)
	} else {
		// a failed resolution must be reported: passing the
		// resulting nil address to ListenTCP would silently listen
		// on an arbitrary port instead of the one asked for
		var tcpaddr *net.TCPAddr
		tcpaddr, err = net.ResolveTCPAddr("tcp", c.Addr)
		if err != nil {
			return nil, err
		}
		l, err = net.ListenTCP("tcp", tcpaddr)
	}
	if err != nil {
		return nil, err
	}

	// set c.Buffer to the default if it's zero
	if c.Buffer == 0 {
		c.Buffer = 64
	}
	// set logger if one was not provided
	if c.Logger == nil {
		c.Logger = slog.New(slog.NewTextHandler(
			os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	}

	// generate id and short id
	id, sid := p.GenId()

	// create the Server
	s := &Server{
		Msgr:     make(chan *p.Msg, c.Buffer),
		Shutdown: make(chan error, 4),
		q:        make(chan bool, 1),
		d:        make(map[string]Handler),
		logd:     make(map[string]func(string, ...any), 5),
		id:       id,
		sid:      sid,
		s:        c.Addr,
		l:        l,
		log:      c.Logger,
		cl:       &sync.Map{},
		t:        time.Duration(c.Timeout) * time.Millisecond,
		rl:       c.Xferlim,
		hk:       c.HMACKey,
		w:        &sync.WaitGroup{},
	}

	// populate logging dispatch table
	s.logd["Debug"] = s.log.Debug
	s.logd["Info"] = s.log.Info
	s.logd["Warn"] = s.log.Warn
	s.logd["Error"] = s.log.Error

	// register the PROTOCHECK handler, called by all clients
	// during connection. this happens before any goroutine is
	// started, so the dispatch table is never written while a
	// connection goroutine may be reading it
	if err = s.Register("PROTOCHECK", protocheck); err != nil {
		l.Close()
		return nil, err
	}

	// add one to waitgroup for s.sockAccept()
	s.w.Add(1)
	// start msgHandler event func
	go msgHandler(s)
	// launch the listener socket event func
	go s.sockAccept()

	s.log.Debug("petrel server up", "sid", s.sid, "addr", c.Addr)
	// all done
	return s, nil
}
// Register installs a Handler in the Server's dispatch table.
//
// 'name' is the command which the Handler will respond to.
//
// 'r' is the Handler function to be invoked when that command is
// dispatched.
//
// An error is returned, and the table left untouched, if a Handler
// is already registered under 'name'.
func (s *Server) Register(name string, r Handler) error {
	_, taken := s.d[name]
	if !taken {
		s.d[name] = r
		return nil
	}
	return fmt.Errorf("handler '%s' already exists", name)
}
// Quit handles shutdown and cleanup, including waiting for any
// connections to terminate. When it returns, all connections are
// fully shut down and no more work will be done.
//
// Quit closes s.q and s.Msgr, so it must run at most once per Server:
// closing an already-closed channel panics. Note that msgHandler
// calls Quit itself when it receives a code 599 Msg.
func (s *Server) Quit() {
s.q <- true // tell sockAccept that the coming Accept error is expected
s.l.Close() // close listener, which unblocks the pending Accept
s.w.Wait() // wait for every goroutine tracked by the waitgroup
close(s.q)
close(s.Msgr)
}
// msgHandler is a function which we'll launch later on as a
// goroutine. It listens to our Server's Msgr channel, checking for a
// few critical things and logging everything else informationally.
//
// It returns when it receives a code 599 Msg (after forwarding it on
// s.Shutdown and calling s.Quit), when it receives a code 199 Msg
// (after forwarding it on s.Shutdown), or when s.Msgr is closed.
func msgHandler(s *Server) {
	for {
		msg, open := <-s.Msgr
		if !open {
			// Msgr has been closed; nothing more will arrive
			return
		}
		if msg == nil {
			// nothing to report; dereferencing would panic
			continue
		}

		switch msg.Code {
		case 599:
			// 599 is "the Server listener socket has
			// died". send the Msg to our main routine, call
			// s.Quit() to clean things up, then stop
			s.Shutdown <- msg
			s.Quit()
			return
		case 199:
			// 199 is "we've been told to quit", so we
			// want to stop here as well
			s.Shutdown <- msg
			return
		}

		// anything else we'll log
		if msg.Code > 1024 {
			// application-defined codes have no entry in
			// p.Stats, so they are logged at Info
			s.logd["Info"](msg.Txt,
				"code", msg.Code,
				"req", msg.Req,
				"cid", msg.Cid,
				"err", msg.Err)
			continue
		}

		// pick the logging func for this code's level. a level
		// with no entry in s.logd would yield a nil func, and
		// calling it would panic, so fall back to Error
		logf, found := s.logd[p.Stats[msg.Code].Lvl]
		if !found || logf == nil {
			logf = s.log.Error
		}
		logf(msg.Txt,
			"code", msg.Code,
			"desc", p.Stats[msg.Code].Txt,
			"req", msg.Req,
			"cid", msg.Cid,
			"err", msg.Err)
	}
}
// protocheck implements the mandatory protocol check handler.
//
// proto is the payload sent by the client; its first byte is compared
// with the first byte of p.Proto. On a match the handler returns
// status 200, otherwise 497. In both cases the server's p.Proto is
// returned as the response so the client can see the server's
// version. The returned error is always nil.
func protocheck(proto []byte) (uint16, []byte, error) {
	// an empty payload cannot match any version; check the length
	// first so a malformed request cannot cause an index panic
	if len(proto) > 0 && proto[0] == p.Proto[0] {
		return 200, p.Proto, nil
	}
	return 497, p.Proto, nil
}
//go:build testing
package server
// RemoveHandler deletes the Handler registered under 'name' from the
// server dispatch table, and reports whether the table is free of
// that name afterwards. It exists so that client error-handling can
// be tested, and is only compiled in when the `testing` build flag is
// provided.
func (s *Server) RemoveHandler(name string) bool {
	delete(s.d, name)
	if _, present := s.d[name]; present {
		return false
	}
	return true
}
// Copyright (c) 2014-2025 Shawn Boyette <shawn@firepear.net>. All
// rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
package petrel
import (
"crypto/sha256"
"fmt"
"strconv"
"time"
)
// Msg is the format which Petrel uses to communicate informational
// messages and errors to its host program via the s.Msgr channel.
type Msg struct {
// Cid is the connection ID that the Msg is coming from.
Cid string
// Seq is the request number that resulted in the Msg.
Seq uint32
// Req is the request made
Req string
// Code is the numeric status indicator.
Code uint16
// Txt is free-form informational content
Txt string
// Err is the error (if any) passed upward as part of the Msg.
Err error
}
// Error satisfies the error interface for Msg. The string it returns
// contains the connection id, sequence number, request, and code,
// then the Stats text for codes up to 1024, then Txt, and finally
// Err when one is present.
func (m *Msg) Error() string {
	// codes up to 1024 have descriptive text in Stats
	described := m.Code <= 1024
	failed := m.Err != nil

	switch {
	case described && failed:
		return fmt.Sprintf("c:%s r:%d (%s, %d, %s) %s : %s",
			m.Cid, m.Seq, m.Req, m.Code, Stats[m.Code].Txt, m.Txt, m.Err)
	case described:
		return fmt.Sprintf("c:%s r:%d (%s, %d, %s) %s",
			m.Cid, m.Seq, m.Req, m.Code, Stats[m.Code].Txt, m.Txt)
	case failed:
		return fmt.Sprintf("c:%s r:%d (%s, %d) %s : %s",
			m.Cid, m.Seq, m.Req, m.Code, m.Txt, m.Err)
	default:
		return fmt.Sprintf("c:%s r:%d (%s, %d) %s",
			m.Cid, m.Seq, m.Req, m.Code, m.Txt)
	}
}
// GenId hashes the current Unix time (in nanoseconds, rendered as a
// base-16 string) with SHA256. It returns the full hexadecimal form
// of the hash, followed by a "short" form made of its first 8
// characters, much as git abbreviates commit hashes.
func GenId() (string, string) {
	stamp := strconv.FormatInt(time.Now().UnixNano(), 16)
	digest := sha256.Sum256([]byte(stamp))
	full := fmt.Sprintf("%x", digest)
	// 8 hex characters encode the first 4 bytes of the digest
	return full, full[:8]
}