From 4e7d1c3a2d042a62aea0151959d8dbc742d4b542 Mon Sep 17 00:00:00 2001 From: me-no-dev Date: Mon, 24 Jun 2019 10:23:58 +0200 Subject: [PATCH] try to catch some edge cases between LwIP and Async tasks --- src/AsyncTCP.cpp | 121 ++++++++++++++++++------------- src/AsyncTCP.h | 182 +++++++++++++++++++++++++---------------------- 2 files changed, 165 insertions(+), 138 deletions(-) diff --git a/src/AsyncTCP.cpp b/src/AsyncTCP.cpp index e9aba3d..76ef765 100644 --- a/src/AsyncTCP.cpp +++ b/src/AsyncTCP.cpp @@ -328,6 +328,7 @@ static int8_t _tcp_accept(void * arg, AsyncClient * client) { typedef struct { struct tcpip_api_call_data call; tcp_pcb * pcb; + AsyncClient * client; int8_t err; union { struct { @@ -351,34 +352,42 @@ typedef struct { static err_t _tcp_output_api(struct tcpip_api_call_data *api_call_msg){ tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; - msg->err = tcp_output(msg->pcb); + msg->err = 0; + if(msg->client && msg->client->pcb()){ + msg->err = tcp_output(msg->pcb); + } return msg->err; } -static esp_err_t _tcp_output(tcp_pcb * pcb) { +static esp_err_t _tcp_output(tcp_pcb * pcb, AsyncClient * client) { if(!pcb){ log_w("pcb is NULL"); return ESP_FAIL; } tcp_api_call_t msg; msg.pcb = pcb; + msg.client = client; tcpip_api_call(_tcp_output_api, (struct tcpip_api_call_data*)&msg); return msg.err; } static err_t _tcp_write_api(struct tcpip_api_call_data *api_call_msg){ tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; - msg->err = tcp_write(msg->pcb, msg->write.data, msg->write.size, msg->write.apiflags); + msg->err = 0; + if(msg->client && msg->client->pcb()){ + msg->err = tcp_write(msg->pcb, msg->write.data, msg->write.size, msg->write.apiflags); + } return msg->err; } -static esp_err_t _tcp_write(tcp_pcb * pcb, const char* data, size_t size, uint8_t apiflags) { +static esp_err_t _tcp_write(tcp_pcb * pcb, const char* data, size_t size, uint8_t apiflags, AsyncClient * client) { if(!pcb){ log_w("pcb is NULL"); return ESP_FAIL; } tcp_api_call_t msg; msg.pcb = pcb; + msg.client = client; msg.write.data = data; msg.write.size = size; msg.write.apiflags = apiflags; @@ -389,22 +398,67 @@ static esp_err_t _tcp_write(tcp_pcb * pcb, const char* data, size_t size, uint8_ static err_t _tcp_recved_api(struct tcpip_api_call_data *api_call_msg){ tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; msg->err = 0; - tcp_recved(msg->pcb, msg->received); + if(msg->client && msg->client->pcb()){ + tcp_recved(msg->pcb, msg->received); + } return msg->err; } -static esp_err_t _tcp_recved(tcp_pcb * pcb, size_t len) { +static esp_err_t _tcp_recved(tcp_pcb * pcb, size_t len, AsyncClient * client) { if(!pcb){ log_w("pcb is NULL"); return ESP_FAIL; } tcp_api_call_t msg; msg.pcb = pcb; + msg.client = client; msg.received = len; tcpip_api_call(_tcp_recved_api, (struct tcpip_api_call_data*)&msg); return msg.err; } +static err_t _tcp_close_api(struct tcpip_api_call_data *api_call_msg){ + tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; + msg->err = 0; + if(msg->client && msg->client->pcb()){ + msg->err = tcp_close(msg->pcb); + } + return msg->err; +} + +static esp_err_t _tcp_close(tcp_pcb * pcb, AsyncClient * client) { + if(!pcb){ + log_w("pcb is NULL"); + return ESP_FAIL; + } + tcp_api_call_t msg; + msg.pcb = pcb; + msg.client = client; + tcpip_api_call(_tcp_close_api, (struct tcpip_api_call_data*)&msg); + return msg.err; +} + +static err_t _tcp_abort_api(struct tcpip_api_call_data *api_call_msg){ + tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; + msg->err = 0; + if(msg->client && msg->client->pcb()){ + tcp_abort(msg->pcb); + } + return msg->err; +} + +static esp_err_t _tcp_abort(tcp_pcb * pcb, AsyncClient * client) { + if(!pcb){ + log_w("pcb is NULL"); + return ESP_FAIL; + } + tcp_api_call_t msg; + msg.pcb = pcb; + msg.client = client; + tcpip_api_call(_tcp_abort_api, (struct tcpip_api_call_data*)&msg); + return msg.err; +} + static err_t _tcp_connect_api(struct tcpip_api_call_data *api_call_msg){ tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; msg->err = tcp_connect(msg->pcb, msg->connect.addr, msg->connect.port, msg->connect.cb); @@ -425,41 +479,6 @@ static esp_err_t _tcp_connect(tcp_pcb * pcb, ip_addr_t * addr, uint16_t port, tc return msg.err; } -static err_t _tcp_close_api(struct tcpip_api_call_data *api_call_msg){ - tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; - msg->err = tcp_close(msg->pcb); - return msg->err; -} - -static esp_err_t _tcp_close(tcp_pcb * pcb) { - if(!pcb){ - log_w("pcb is NULL"); - return ESP_FAIL; - } - tcp_api_call_t msg; - msg.pcb = pcb; - tcpip_api_call(_tcp_close_api, (struct tcpip_api_call_data*)&msg); - return msg.err; -} - -static err_t _tcp_abort_api(struct tcpip_api_call_data *api_call_msg){ - tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; - msg->err = 0; - tcp_abort(msg->pcb); - return msg->err; -} - -static esp_err_t _tcp_abort(tcp_pcb * pcb) { - if(!pcb){ - log_w("pcb is NULL"); - return ESP_FAIL; - } - tcp_api_call_t msg; - msg.pcb = pcb; - tcpip_api_call(_tcp_abort_api, (struct tcpip_api_call_data*)&msg); - return msg.err; -} - static err_t _tcp_bind_api(struct tcpip_api_call_data *api_call_msg){ tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg; msg->err = tcp_bind(msg->pcb, msg->bind.addr, msg->bind.port); @@ -606,7 +625,7 @@ int8_t AsyncClient::_close(){ tcp_err(_pcb, NULL); tcp_poll(_pcb, NULL, 0); _tcp_clear_events(this); - err = _tcp_close(_pcb); + err = _tcp_close(_pcb, this); if(err != ERR_OK) { err = abort(); } @@ -704,7 +723,7 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) { if(!_ack_pcb) { _rx_ack_len += b->len; } else if(_pcb) { - _tcp_recved(_pcb, b->len); + _tcp_recved(_pcb, b->len, this); } pbuf_free(b); } @@ -773,7 +792,7 @@ bool AsyncClient::connect(const char* host, uint16_t port){ int8_t AsyncClient::abort(){ if(_pcb) { - _tcp_abort(_pcb); + _tcp_abort(_pcb, this); _pcb = NULL; } return ERR_ABRT; @@ -781,7 +800,7 @@ int8_t AsyncClient::abort(){ void AsyncClient::close(bool now){ if(_pcb){ - _tcp_recved(_pcb, _rx_ack_len); + _tcp_recved(_pcb, _rx_ack_len, this); } _close(); } @@ -833,7 +852,7 @@ size_t AsyncClient::add(const char* data, size_t size, uint8_t apiflags) { } size_t will_send = (room < size) ? room : size; int8_t err = ERR_OK; - err = _tcp_write(_pcb, data, will_send, apiflags); + err = _tcp_write(_pcb, data, will_send, apiflags, this); if(err != ERR_OK) { return 0; } @@ -842,7 +861,7 @@ size_t AsyncClient::add(const char* data, size_t size, uint8_t apiflags) { bool AsyncClient::send(){ int8_t err = ERR_OK; - err = _tcp_output(_pcb); + err = _tcp_output(_pcb, this); if(err == ERR_OK){ _pcb_busy = true; _pcb_sent_at = millis(); @@ -855,7 +874,7 @@ size_t AsyncClient::ack(size_t len){ if(len > _rx_ack_len) len = _rx_ack_len; if(len){ - _tcp_recved(_pcb, len); + _tcp_recved(_pcb, len, this); } _rx_ack_len -= len; return len; @@ -996,7 +1015,7 @@ void AsyncClient::ackPacket(struct pbuf * pb){ if(!pb){ return; } - _tcp_recved(_pcb, pb->len); + _tcp_recved(_pcb, pb->len, this); pbuf_free(pb); } @@ -1232,7 +1251,7 @@ void AsyncServer::begin(){ err = _tcp_bind(_pcb, &local_addr, _port); if (err != ERR_OK) { - _tcp_close(_pcb); + _tcp_close(_pcb, NULL); log_e("bind error: %d", err); return; } @@ -1251,7 +1270,7 @@ void AsyncServer::end(){ if(_pcb){ tcp_arg(_pcb, NULL); tcp_accept(_pcb, NULL); - _tcp_abort(_pcb); + _tcp_abort(_pcb, NULL); _pcb = NULL; } } diff --git a/src/AsyncTCP.h b/src/AsyncTCP.h index 9f1f79a..05650fc 100644 --- a/src/AsyncTCP.h +++ b/src/AsyncTCP.h @@ -53,6 +53,92 @@ struct tcp_pcb; struct ip_addr; class AsyncClient { + public: + AsyncClient(tcp_pcb* pcb = 0); + ~AsyncClient(); + + AsyncClient & operator=(const AsyncClient &other); + AsyncClient & operator+=(const AsyncClient &other); + + bool operator==(const AsyncClient &other); + + bool operator!=(const AsyncClient &other) { + return !(*this == other); + } + bool connect(IPAddress ip, uint16_t port); + bool connect(const char* host, uint16_t port); + void close(bool now = false); + void stop(); + int8_t abort(); + bool free(); + + bool canSend();//ack is not pending + size_t space();//space available in the TCP window + size_t add(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY);//add for sending + bool send();//send all data added with the method above + + //write equals add()+send() + size_t write(const char* data); + size_t write(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY); //only when canSend() == true + + uint8_t state(); + bool connecting(); + bool connected(); + bool disconnecting(); + bool disconnected(); + bool freeable();//disconnected or disconnecting + + uint16_t getMss(); + + uint32_t getRxTimeout(); + void setRxTimeout(uint32_t timeout);//no RX data timeout for the connection in seconds + + uint32_t getAckTimeout(); + void setAckTimeout(uint32_t timeout);//no ACK timeout for the last sent packet in milliseconds + + void setNoDelay(bool nodelay); + bool getNoDelay(); + + uint32_t getRemoteAddress(); + uint16_t getRemotePort(); + uint32_t getLocalAddress(); + uint16_t getLocalPort(); + + //compatibility + IPAddress remoteIP(); + uint16_t remotePort(); + IPAddress localIP(); + uint16_t localPort(); + + void onConnect(AcConnectHandler cb, void* arg = 0); //on successful connect + void onDisconnect(AcConnectHandler cb, void* arg = 0); //disconnected + void onAck(AcAckHandler cb, void* arg = 0); //ack received + void onError(AcErrorHandler cb, void* arg = 0); //unsuccessful connect or error + void onData(AcDataHandler cb, void* arg = 0); //data received (called if onPacket is not used) + void onPacket(AcPacketHandler cb, void* arg = 0); //data received + void onTimeout(AcTimeoutHandler cb, void* arg = 0); //ack timeout + void onPoll(AcConnectHandler cb, void* arg = 0); //every 125ms when connected + + void ackPacket(struct pbuf * pb);//ack pbuf from onPacket + size_t ack(size_t len); //ack data that you have not acked using the method below + void ackLater(){ _ack_pcb = false; } //will not ack the current packet. Call from onData + + const char * errorToString(int8_t error); + const char * stateToString(); + + //Do not use any of the functions below! + static int8_t _s_poll(void *arg, struct tcp_pcb *tpcb); + static int8_t _s_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *pb, int8_t err); + static int8_t _s_fin(void *arg, struct tcp_pcb *tpcb, int8_t err); + static int8_t _s_lwip_fin(void *arg, struct tcp_pcb *tpcb, int8_t err); + static void _s_error(void *arg, int8_t err); + static int8_t _s_sent(void *arg, struct tcp_pcb *tpcb, uint16_t len); + static int8_t _s_connected(void* arg, void* tpcb, int8_t err); + static void _s_dns_found(const char *name, struct ip_addr *ipaddr, void *arg); + + int8_t _recv(tcp_pcb* pcb, pbuf* pb, int8_t err); + tcp_pcb * pcb(){ return _pcb; } + protected: tcp_pcb* _pcb; @@ -91,100 +177,13 @@ class AsyncClient { int8_t _lwip_fin(tcp_pcb* pcb, int8_t err); void _dns_found(struct ip_addr *ipaddr); - public: AsyncClient* prev; AsyncClient* next; - - AsyncClient(tcp_pcb* pcb = 0); - ~AsyncClient(); - - AsyncClient & operator=(const AsyncClient &other); - AsyncClient & operator+=(const AsyncClient &other); - - bool operator==(const AsyncClient &other); - - bool operator!=(const AsyncClient &other) { - return !(*this == other); - } - bool connect(IPAddress ip, uint16_t port); - bool connect(const char* host, uint16_t port); - void close(bool now = false); - void stop(); - int8_t abort(); - bool free(); - - bool canSend();//ack is not pending - size_t space(); - size_t add(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY);//add for sending - bool send();//send all data added with the method above - size_t ack(size_t len); //ack data that you have not acked using the method below - void ackLater(){ _ack_pcb = false; } //will not ack the current packet. Call from onData - - size_t write(const char* data); - size_t write(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY); //only when canSend() == true - - uint8_t state(); - bool connecting(); - bool connected(); - bool disconnecting(); - bool disconnected(); - bool freeable();//disconnected or disconnecting - - uint16_t getMss(); - uint32_t getRxTimeout(); - void setRxTimeout(uint32_t timeout);//no RX data timeout for the connection in seconds - uint32_t getAckTimeout(); - void setAckTimeout(uint32_t timeout);//no ACK timeout for the last sent packet in milliseconds - void setNoDelay(bool nodelay); - bool getNoDelay(); - uint32_t getRemoteAddress(); - uint16_t getRemotePort(); - uint32_t getLocalAddress(); - uint16_t getLocalPort(); - - IPAddress remoteIP(); - uint16_t remotePort(); - IPAddress localIP(); - uint16_t localPort(); - - void onConnect(AcConnectHandler cb, void* arg = 0); //on successful connect - void onDisconnect(AcConnectHandler cb, void* arg = 0); //disconnected - void onAck(AcAckHandler cb, void* arg = 0); //ack received - void onError(AcErrorHandler cb, void* arg = 0); //unsuccessful connect or error - void onData(AcDataHandler cb, void* arg = 0); //data received (called if onPacket is not used) - void onPacket(AcPacketHandler cb, void* arg = 0); //data received - void onTimeout(AcTimeoutHandler cb, void* arg = 0); //ack timeout - void onPoll(AcConnectHandler cb, void* arg = 0); //every 125ms when connected - - void ackPacket(struct pbuf * pb); - - const char * errorToString(int8_t error); - const char * stateToString(); - - int8_t _recv(tcp_pcb* pcb, pbuf* pb, int8_t err); - - static int8_t _s_poll(void *arg, struct tcp_pcb *tpcb); - static int8_t _s_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *pb, int8_t err); - static int8_t _s_fin(void *arg, struct tcp_pcb *tpcb, int8_t err); - static int8_t _s_lwip_fin(void *arg, struct tcp_pcb *tpcb, int8_t err); - static void _s_error(void *arg, int8_t err); - static int8_t _s_sent(void *arg, struct tcp_pcb *tpcb, uint16_t len); - static int8_t _s_connected(void* arg, void* tpcb, int8_t err); - static void _s_dns_found(const char *name, struct ip_addr *ipaddr, void *arg); }; class AsyncServer { - protected: - uint16_t _port; - IPAddress _addr; - bool _noDelay; - tcp_pcb* _pcb; - AcConnectHandler _connect_cb; - void* _connect_cb_arg; - public: - AsyncServer(IPAddress addr, uint16_t port); AsyncServer(uint16_t port); ~AsyncServer(); @@ -195,9 +194,18 @@ class AsyncServer { bool getNoDelay(); uint8_t status(); + //Do not use any of the functions below! static int8_t _s_accept(void *arg, tcp_pcb* newpcb, int8_t err); static int8_t _s_accepted(void *arg, AsyncClient* client); + protected: + uint16_t _port; + IPAddress _addr; + bool _noDelay; + tcp_pcb* _pcb; + AcConnectHandler _connect_cb; + void* _connect_cb_arg; + int8_t _accept(tcp_pcb* newpcb, int8_t err); int8_t _accepted(AsyncClient* client); };