#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "../rpca.gen.h" // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- // BUFFERS: struct buffer { uint8_t *data; unsigned length; }; static void buffer_free_data(struct buffer *const buf) { free(buf->data); } // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- // CONNECTIONS: struct connection { int fd; unsigned state; struct bufferevent *be; struct sockaddr_in sin; struct event connect_event; void (*state_cb) (struct connection *, void *); void (*read_cb) (struct connection *, void *); void *arg; }; // These are the values of the @state member of a struct connection enum { kConnStateResolving, kConnStateConnecting, kConnStateUp, kConnStateDead }; // ----------------------------------------------------------------------------- // This is called when something is fatally wrong with the connection // ----------------------------------------------------------------------------- static void connection_fatal(struct connection *const conn) { close(conn->fd); conn->state = kConnStateDead; conn->state_cb(conn, conn->arg); } // ----------------------------------------------------------------------------- // A callback which is called when there's data in the bufferevent to read // ----------------------------------------------------------------------------- static void connection_read(struct bufferevent *const be, void *arg) { struct connection *const conn = (struct connection *) arg; conn->read_cb(conn, conn->arg); } // ----------------------------------------------------------------------------- // A callback which is called when the bufferevent's outbound queue is empty. // ----------------------------------------------------------------------------- static void connection_write(struct bufferevent *const be, void *arg) { // * Stop trying to write struct connection *const conn = (struct connection *) arg; bufferevent_enable(conn->be, EV_READ); } // ----------------------------------------------------------------------------- // This callback is made when there's an error on the fd. // ----------------------------------------------------------------------------- static void connection_error(struct bufferevent *const be, short what, void *arg) { struct connection *const conn = (struct connection *) arg; connection_fatal(conn); } // ----------------------------------------------------------------------------- // Called when the connection's fd is ready // ----------------------------------------------------------------------------- static void connection_up(struct connection *const conn) { conn->be = bufferevent_new (conn->fd, connection_read, connection_write, connection_error, conn); conn->state = kConnStateUp; conn->state_cb(conn, conn->arg); bufferevent_enable(conn->be, EV_READ); } // ----------------------------------------------------------------------------- // A callback for when the connecting socket is writable. // ----------------------------------------------------------------------------- static void connection_connected(int fd, short event, void *arg) { struct connection *const conn = (struct connection *) arg; int error; socklen_t error_len = sizeof(error); getsockopt(conn->fd, SOL_SOCKET, SO_ERROR, &error, &error_len); if (error) { connection_fatal(conn); } else { connection_up(conn); } } // ----------------------------------------------------------------------------- // Called when conn->sin has already been filled in to start the connection. // ----------------------------------------------------------------------------- static void connection_connect(struct connection *const conn) { const int n = connect(conn->fd, (struct sockaddr *) &conn->sin, sizeof(conn->sin)); if (n == -1) { switch (errno) { case EINPROGRESS: case EWOULDBLOCK: event_set(&conn->connect_event, conn->fd, EV_WRITE, connection_connected, conn); event_add(&conn->connect_event, NULL); return; default: connection_fatal(conn); } } else { connection_up(conn); } } // ----------------------------------------------------------------------------- // This is a callback which is made when a connection's hostname is resolved. // ----------------------------------------------------------------------------- static void connection_resolved(int result, char type, int count, int ttl, void *addresses, void *arg) { struct connection *const conn = (struct connection *) arg; if (result != DNS_ERR_NONE || count < 1) { connection_fatal(conn); return; } memcpy(&conn->sin.sin_addr.s_addr, addresses, 4); connection_connect(conn); } static void connection_free(struct connection *const conn) { close(conn->fd); if (conn->be) bufferevent_free(conn->be); free(conn); } // ----------------------------------------------------------------------------- // Create a new connection to the given host:port pair // ----------------------------------------------------------------------------- static struct connection * connection_new(const char *const host, int port, void (*state_cb) (struct connection *, void *), void (*read_cb) (struct connection *, void *), void *arg) { struct connection *const conn = malloc(sizeof(struct connection)); memset(conn, 0, sizeof(struct connection)); conn->fd = socket(AF_INET, SOCK_STREAM, 0); // set socket non-blocking const long flags = fcntl(conn->fd, F_GETFL); fcntl(conn->fd, F_SETFL, flags | O_NONBLOCK); conn->state_cb = state_cb; conn->read_cb = read_cb; conn->arg = arg; conn->sin.sin_family = PF_INET; conn->sin.sin_port = htons(port); // see if the given hostname is an ip address if (inet_aton(host, &conn->sin.sin_addr)) { // it's an IP address connection_connect(conn); } else { // it's not an IP address, we need to resolve it conn->state = kConnStateResolving; evdns_resolve_ipv4(host, 0, connection_resolved, conn); } return conn; } // ----------------------------------------------------------------------------- // Write some data to the connection // ----------------------------------------------------------------------------- static void connection_enqueue(struct connection *const conn, const struct buffer *const buf) { bufferevent_write(conn->be, buf->data, buf->length); } // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- // CHANNELS // ----------------------------------------------------------------------------- // This is the information which is kept about each outstanding RPC // ----------------------------------------------------------------------------- typedef void (*rpc_callback) (struct rpcreply *, const uint8_t *body, unsigned len, void *); struct dispatch { rpc_callback cb; void *arg; struct event timeout_event; char timeout_event_enabled; }; static void dispatch_free(struct dispatch *const disp) { if (disp->timeout_event_enabled) { event_del(&disp->timeout_event); } free(disp); } static void prelude_serialise(uint8_t *const out, unsigned header_len, unsigned body_len, unsigned chksum) { uint32_t header_len_be = htonl(header_len); uint32_t body_len_be = htonl(body_len); uint32_t chksum_be = htonl(chksum); memcpy(out, &header_len_be, 4); memcpy(out + 4, &body_len_be, 4); memcpy(out + 8, &chksum_be, 4); } static const unsigned kChanStateUp = 1; static const unsigned kChanStateDown = 2; struct channel { struct connection *conn; char *host; int port; char *service; int state; uint32_t next_id; struct event reconnect_timeout; char reconnect_timeout_enabled; // This is a mapping from RPC id to a pair of pointers to buffers (which is // the header and payload of that RPC). These are requests which are getting // buffered here until we have a connection to send them on. In the case that // they time out, they are removed from this structure. GTree *outq; // This is mapping from RPC id to a dispatch structure GTree *dispatch; }; static void channel_free(struct channel *const chan) { free(chan->host); free(chan->service); g_tree_destroy(chan->outq); g_tree_destroy(chan->dispatch); if (chan->reconnect_timeout_enabled) evtimer_del(&chan->reconnect_timeout); // FIXME: delete connection (if nonnull) } // ----------------------------------------------------------------------------- // This is the key compare function for our dispatch and outq trees. The keys // are pointers with uints packed into them. // ----------------------------------------------------------------------------- static int channel_uint_key_cmp(gconstpointer a, gconstpointer b, gpointer arg) { const guint ua = GPOINTER_TO_UINT(a); const guint ub = GPOINTER_TO_UINT(b); if (ua < ub) { return -1; } else if (ua > ub) { return 1; } else { return 0; } } // ----------------------------------------------------------------------------- // Our keys are uints packed into the pointer, thus we don't need to do // anything to destroy them // ----------------------------------------------------------------------------- static void channel_uint_key_destroy(gpointer key) { } // ----------------------------------------------------------------------------- // Free an outq tree values, which is an alloced array of two pointers to // buffers. // ----------------------------------------------------------------------------- static void channel_outq_value_destroy(gpointer value) { struct buffer *buffers = (struct buffer *) value; buffer_free_data(&buffers[0]); buffer_free_data(&buffers[1]); buffer_free_data(&buffers[2]); free(buffers); } // ----------------------------------------------------------------------------- // This is called when we destroy a dispatch tree to delete all the values. In // a dispatch tree, the values are allocated dispatch objects // ----------------------------------------------------------------------------- static void channel_dispatch_value_destroy(gpointer value) { struct dispatch *const disp = (struct dispatch *) value; dispatch_free(disp); } // ----------------------------------------------------------------------------- // When a connect is made, this function is called for each enqueued block of // data. These blocks of data are rpcrequests (and their payloads) for RPCs // which have been waiting for the connection to come up. // ----------------------------------------------------------------------------- static gboolean channel_block_to_conn(gpointer key, gpointer value, gpointer data) { struct channel *const chan = (struct channel *) data; struct buffer *const bufs = (struct buffer *) value; // * The arg is a pointer to an array of two pointers to buffers connection_enqueue(chan->conn, &bufs[0]); connection_enqueue(chan->conn, &bufs[1]); connection_enqueue(chan->conn, &bufs[2]); return FALSE; } static const unsigned kMaxSerialisedRPCHeaderSize = 512; static const unsigned kMaxSerialisedPayloadSize = 1024 * 1024; static void evbuffer_fake(struct evbuffer *const buf, const uint8_t *header, unsigned length) { memset(buf, 0, sizeof(struct evbuffer)); buf->buffer = (uint8_t *) header; buf->totallen = buf->off = length; } static void channel_have_message(struct channel *const chan, const uint8_t *const header, const unsigned header_len, const uint8_t *const body, const unsigned body_len) { struct inboundreply reply; struct evbuffer buf; memset(&reply, 0, sizeof(reply)); evbuffer_fake(&buf, header, header_len); if (inboundreply_unmarshal(&reply, &buf)) { fprintf(stderr, "Failed to decode RPC reply header"); return; } if (reply.rpc_set) { const gpointer id_as_ptr = GUINT_TO_POINTER(reply.id_data); const gpointer value = g_tree_lookup(chan->dispatch, id_as_ptr); if (!value) { fprintf(stderr, "Got reply for unknown RPC id %d\n", reply.id_data); return; } struct dispatch *const disp = (struct dispatch *) value; disp->cb(reply.rpc_data, body, body_len, disp->arg); if (disp->timeout_event_enabled) { evtimer_del(&disp->timeout_event); disp->timeout_event_enabled = 0; } g_tree_remove(chan->dispatch, id_as_ptr); } } // ----------------------------------------------------------------------------- // This callback is made when there's data in the connection's input buffer // ----------------------------------------------------------------------------- static void channel_conn_read(struct connection *conn, void *arg) { struct channel *const chan = (struct channel *) arg; // We need to read the 12 byte header first fprintf(stderr, "channel_conn_read\n"); for (;;) { if (EVBUFFER_LENGTH(conn->be->input) < 12) return; uint32_t t; memcpy(&t, EVBUFFER_DATA(conn->be->input), sizeof(t)); const uint32_t header_len = ntohl(t); memcpy(&t, EVBUFFER_DATA(conn->be->input) + sizeof(t), sizeof(t)); const uint32_t body_len = ntohl(t); if (header_len > kMaxSerialisedRPCHeaderSize || body_len > kMaxSerialisedPayloadSize) { fprintf(stderr, "Got oversized message: %u %u\n", header_len, body_len); // FIXME: kill channel return; } // * If we don't have the whole message, give up const uint32_t total_len = 12 + header_len + body_len; if (EVBUFFER_LENGTH(conn->be->input) < (12 + header_len + body_len)) return; channel_have_message(chan, EVBUFFER_DATA(conn->be->input) + 12, header_len, EVBUFFER_DATA(conn->be->input) + 12 + header_len, body_len); evbuffer_drain(conn->be->input, total_len); } } static void channel_conn_state(struct connection *conn, void *arg); static struct timeval kReconnectTimeout = {5, 0}; // ----------------------------------------------------------------------------- // This is a timeout callback which is called when we want to reconnect // ----------------------------------------------------------------------------- static void channel_reconnect(int fd, short events, void *arg) { struct channel *const chan = (struct channel *) arg; chan->reconnect_timeout_enabled = 0; connection_free(chan->conn); chan->conn = connection_new(chan->host, chan->port, channel_conn_state, channel_conn_read, chan); } // ----------------------------------------------------------------------------- // This callback is made when the state of the connection changes. // ----------------------------------------------------------------------------- static void channel_conn_state(struct connection *conn, void *arg) { struct channel *const chan = (struct channel *) arg; switch (conn->state) { case kConnStateUp: // the connection has just finished connecting // in this case we need to move all the pending data to the connection chan->state = kChanStateUp; g_tree_foreach(chan->outq, channel_block_to_conn, chan); g_tree_destroy(chan->outq); chan->outq = g_tree_new_full (channel_uint_key_cmp, NULL, channel_uint_key_destroy, channel_outq_value_destroy); break; case kConnStateDead: // the connection has just failed. Try to connect again chan->state = kChanStateDown; evtimer_set(&chan->reconnect_timeout, channel_reconnect, chan); evtimer_add(&chan->reconnect_timeout, &kReconnectTimeout); chan->reconnect_timeout_enabled = 1; break; } } static struct channel * channel_new(const char *const host, int port, const char *const service) { struct channel *const chan = (struct channel *) malloc(sizeof(struct channel)); memset(chan, 0, sizeof(struct channel)); chan->host = strdup(host); chan->port = port; chan->service = strdup(service); chan->conn = connection_new(host, port, channel_conn_state, channel_conn_read, chan); chan->dispatch = g_tree_new_full (channel_uint_key_cmp, NULL, channel_uint_key_destroy, channel_dispatch_value_destroy); chan->outq = g_tree_new_full (channel_uint_key_cmp, NULL, channel_uint_key_destroy, channel_outq_value_destroy); return chan; } static void channel_enqueue(struct channel *const chan, const char *method, rpc_callback cb, void *arg, const uint8_t *body, const unsigned body_len, unsigned timeout) { const uint32_t id = chan->next_id++; struct outboundrequest obr; struct rpcrequest rpcr; memset(&rpcr, 0, sizeof(rpcr)); rpcrequest_service_assign(&rpcr, chan->service); rpcrequest_method_assign(&rpcr, method); memset(&obr, 0, sizeof(obr)); outboundrequest_id_assign(&obr, id); outboundrequest_rpc_assign(&obr, &rpcr); struct evbuffer *buf = evbuffer_new(); outboundrequest_marshal(buf, &obr); const unsigned header_len = EVBUFFER_LENGTH(buf); uint8_t prelude[12]; prelude_serialise(prelude, header_len, body_len, 0); if (chan->state != kChanStateUp) { // * Enqueue the RPC for when we have a connection uint8_t *const header = (uint8_t *) malloc(header_len); memcpy(header, EVBUFFER_DATA(buf), header_len); struct buffer *const buffers = (struct buffer *) malloc(sizeof(struct buffer ) * 3); uint8_t *const prelude_copy = (uint8_t *) malloc(12); memcpy(prelude_copy, prelude, 12); buffers[0].data = prelude_copy; buffers[0].length = 12; buffers[1].data = header; buffers[1].length = header_len; uint8_t *const body_copy = (uint8_t *) malloc(body_len); memcpy(body_copy, body, body_len); buffers[2].data = (uint8_t *) body; buffers[2].length = body_len; g_tree_insert(chan->outq, GUINT_TO_POINTER(id), buffers); } else { // * We have a connection - send the request struct buffer dbuf; dbuf.data = prelude; dbuf.length = 12; connection_enqueue(chan->conn, &dbuf); dbuf.data = EVBUFFER_DATA(buf); dbuf.length = header_len; connection_enqueue(chan->conn, &dbuf); dbuf.data = (uint8_t *) body; dbuf.length = body_len; connection_enqueue(chan->conn, &dbuf); } // * Record the RPC id with the given callback information struct dispatch *const disp = (struct dispatch *) malloc(sizeof(struct dispatch)); memset(disp, 0, sizeof(struct dispatch)); disp->cb = cb; disp->arg = arg; disp->timeout_event_enabled = 0; g_tree_insert(chan->dispatch, GUINT_TO_POINTER(id), disp); evbuffer_free(buf); } // ----------------------------------------------------------------------------- // ----------------------------------------------------------------------------- static void callback(struct rpcreply *reply, const uint8_t *body, unsigned len, void *arg) { fprintf(stderr, "OMG\n"); } int main() { event_init(); evdns_init(); evtag_init(); struct channel *const chan = channel_new("127.0.0.1", 4545, "test"); channel_enqueue(chan, "halfsec", callback, NULL, NULL, 0, 0); event_loop(0); return 0; }