try to catch some edge cases between LwIP and Async tasks

This commit is contained in:
me-no-dev
2019-06-24 10:23:58 +02:00
parent b5c6167a3f
commit 4e7d1c3a2d
2 changed files with 165 additions and 138 deletions

View File

@@ -328,6 +328,7 @@ static int8_t _tcp_accept(void * arg, AsyncClient * client) {
typedef struct {
struct tcpip_api_call_data call;
tcp_pcb * pcb;
AsyncClient * client;
int8_t err;
union {
struct {
@@ -351,34 +352,42 @@ typedef struct {
static err_t _tcp_output_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = tcp_output(msg->pcb);
msg->err = 0;
if(msg->client && msg->client->pcb()){
msg->err = tcp_output(msg->pcb);
}
return msg->err;
}
static esp_err_t _tcp_output(tcp_pcb * pcb) {
static esp_err_t _tcp_output(tcp_pcb * pcb, AsyncClient * client) {
if(!pcb){
log_w("pcb is NULL");
return ESP_FAIL;
}
tcp_api_call_t msg;
msg.pcb = pcb;
msg.client = client;
tcpip_api_call(_tcp_output_api, (struct tcpip_api_call_data*)&msg);
return msg.err;
}
static err_t _tcp_write_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = tcp_write(msg->pcb, msg->write.data, msg->write.size, msg->write.apiflags);
msg->err = 0;
if(msg->client && msg->client->pcb()){
msg->err = tcp_write(msg->pcb, msg->write.data, msg->write.size, msg->write.apiflags);
}
return msg->err;
}
static esp_err_t _tcp_write(tcp_pcb * pcb, const char* data, size_t size, uint8_t apiflags) {
static esp_err_t _tcp_write(tcp_pcb * pcb, const char* data, size_t size, uint8_t apiflags, AsyncClient * client) {
if(!pcb){
log_w("pcb is NULL");
return ESP_FAIL;
}
tcp_api_call_t msg;
msg.pcb = pcb;
msg.client = client;
msg.write.data = data;
msg.write.size = size;
msg.write.apiflags = apiflags;
@@ -389,22 +398,67 @@ static esp_err_t _tcp_write(tcp_pcb * pcb, const char* data, size_t size, uint8_
static err_t _tcp_recved_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = 0;
tcp_recved(msg->pcb, msg->received);
if(msg->client && msg->client->pcb()){
tcp_recved(msg->pcb, msg->received);
}
return msg->err;
}
static esp_err_t _tcp_recved(tcp_pcb * pcb, size_t len) {
static esp_err_t _tcp_recved(tcp_pcb * pcb, size_t len, AsyncClient * client) {
if(!pcb){
log_w("pcb is NULL");
return ESP_FAIL;
}
tcp_api_call_t msg;
msg.pcb = pcb;
msg.client = client;
msg.received = len;
tcpip_api_call(_tcp_recved_api, (struct tcpip_api_call_data*)&msg);
return msg.err;
}
static err_t _tcp_close_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = 0;
if(msg->client && msg->client->pcb()){
msg->err = tcp_close(msg->pcb);
}
return msg->err;
}
static esp_err_t _tcp_close(tcp_pcb * pcb, AsyncClient * client) {
if(!pcb){
log_w("pcb is NULL");
return ESP_FAIL;
}
tcp_api_call_t msg;
msg.pcb = pcb;
msg.client = client;
tcpip_api_call(_tcp_close_api, (struct tcpip_api_call_data*)&msg);
return msg.err;
}
static err_t _tcp_abort_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = 0;
if(msg->client && msg->client->pcb()){
tcp_abort(msg->pcb);
}
return msg->err;
}
static esp_err_t _tcp_abort(tcp_pcb * pcb, AsyncClient * client) {
if(!pcb){
log_w("pcb is NULL");
return ESP_FAIL;
}
tcp_api_call_t msg;
msg.pcb = pcb;
msg.client = client;
tcpip_api_call(_tcp_abort_api, (struct tcpip_api_call_data*)&msg);
return msg.err;
}
static err_t _tcp_connect_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = tcp_connect(msg->pcb, msg->connect.addr, msg->connect.port, msg->connect.cb);
@@ -425,41 +479,6 @@ static esp_err_t _tcp_connect(tcp_pcb * pcb, ip_addr_t * addr, uint16_t port, tc
return msg.err;
}
static err_t _tcp_close_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = tcp_close(msg->pcb);
return msg->err;
}
static esp_err_t _tcp_close(tcp_pcb * pcb) {
if(!pcb){
log_w("pcb is NULL");
return ESP_FAIL;
}
tcp_api_call_t msg;
msg.pcb = pcb;
tcpip_api_call(_tcp_close_api, (struct tcpip_api_call_data*)&msg);
return msg.err;
}
static err_t _tcp_abort_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = 0;
tcp_abort(msg->pcb);
return msg->err;
}
static esp_err_t _tcp_abort(tcp_pcb * pcb) {
if(!pcb){
log_w("pcb is NULL");
return ESP_FAIL;
}
tcp_api_call_t msg;
msg.pcb = pcb;
tcpip_api_call(_tcp_abort_api, (struct tcpip_api_call_data*)&msg);
return msg.err;
}
static err_t _tcp_bind_api(struct tcpip_api_call_data *api_call_msg){
tcp_api_call_t * msg = (tcp_api_call_t *)api_call_msg;
msg->err = tcp_bind(msg->pcb, msg->bind.addr, msg->bind.port);
@@ -606,7 +625,7 @@ int8_t AsyncClient::_close(){
tcp_err(_pcb, NULL);
tcp_poll(_pcb, NULL, 0);
_tcp_clear_events(this);
err = _tcp_close(_pcb);
err = _tcp_close(_pcb, this);
if(err != ERR_OK) {
err = abort();
}
@@ -704,7 +723,7 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) {
if(!_ack_pcb) {
_rx_ack_len += b->len;
} else if(_pcb) {
_tcp_recved(_pcb, b->len);
_tcp_recved(_pcb, b->len, this);
}
pbuf_free(b);
}
@@ -773,7 +792,7 @@ bool AsyncClient::connect(const char* host, uint16_t port){
int8_t AsyncClient::abort(){
if(_pcb) {
_tcp_abort(_pcb);
_tcp_abort(_pcb, this);
_pcb = NULL;
}
return ERR_ABRT;
@@ -781,7 +800,7 @@ int8_t AsyncClient::abort(){
void AsyncClient::close(bool now){
if(_pcb){
_tcp_recved(_pcb, _rx_ack_len);
_tcp_recved(_pcb, _rx_ack_len, this);
}
_close();
}
@@ -833,7 +852,7 @@ size_t AsyncClient::add(const char* data, size_t size, uint8_t apiflags) {
}
size_t will_send = (room < size) ? room : size;
int8_t err = ERR_OK;
err = _tcp_write(_pcb, data, will_send, apiflags);
err = _tcp_write(_pcb, data, will_send, apiflags, this);
if(err != ERR_OK) {
return 0;
}
@@ -842,7 +861,7 @@ size_t AsyncClient::add(const char* data, size_t size, uint8_t apiflags) {
bool AsyncClient::send(){
int8_t err = ERR_OK;
err = _tcp_output(_pcb);
err = _tcp_output(_pcb, this);
if(err == ERR_OK){
_pcb_busy = true;
_pcb_sent_at = millis();
@@ -855,7 +874,7 @@ size_t AsyncClient::ack(size_t len){
if(len > _rx_ack_len)
len = _rx_ack_len;
if(len){
_tcp_recved(_pcb, len);
_tcp_recved(_pcb, len, this);
}
_rx_ack_len -= len;
return len;
@@ -996,7 +1015,7 @@ void AsyncClient::ackPacket(struct pbuf * pb){
if(!pb){
return;
}
_tcp_recved(_pcb, pb->len);
_tcp_recved(_pcb, pb->len, this);
pbuf_free(pb);
}
@@ -1232,7 +1251,7 @@ void AsyncServer::begin(){
err = _tcp_bind(_pcb, &local_addr, _port);
if (err != ERR_OK) {
_tcp_close(_pcb);
_tcp_close(_pcb, NULL);
log_e("bind error: %d", err);
return;
}
@@ -1251,7 +1270,7 @@ void AsyncServer::end(){
if(_pcb){
tcp_arg(_pcb, NULL);
tcp_accept(_pcb, NULL);
_tcp_abort(_pcb);
_tcp_abort(_pcb, NULL);
_pcb = NULL;
}
}

View File

@@ -53,6 +53,92 @@ struct tcp_pcb;
struct ip_addr;
class AsyncClient {
public:
AsyncClient(tcp_pcb* pcb = 0);
~AsyncClient();
AsyncClient & operator=(const AsyncClient &other);
AsyncClient & operator+=(const AsyncClient &other);
bool operator==(const AsyncClient &other);
bool operator!=(const AsyncClient &other) {
return !(*this == other);
}
bool connect(IPAddress ip, uint16_t port);
bool connect(const char* host, uint16_t port);
void close(bool now = false);
void stop();
int8_t abort();
bool free();
bool canSend();//ack is not pending
size_t space();//space available in the TCP window
size_t add(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY);//add for sending
bool send();//send all data added with the method above
//write equals add()+send()
size_t write(const char* data);
size_t write(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY); //only when canSend() == true
uint8_t state();
bool connecting();
bool connected();
bool disconnecting();
bool disconnected();
bool freeable();//disconnected or disconnecting
uint16_t getMss();
uint32_t getRxTimeout();
void setRxTimeout(uint32_t timeout);//no RX data timeout for the connection in seconds
uint32_t getAckTimeout();
void setAckTimeout(uint32_t timeout);//no ACK timeout for the last sent packet in milliseconds
void setNoDelay(bool nodelay);
bool getNoDelay();
uint32_t getRemoteAddress();
uint16_t getRemotePort();
uint32_t getLocalAddress();
uint16_t getLocalPort();
//compatibility
IPAddress remoteIP();
uint16_t remotePort();
IPAddress localIP();
uint16_t localPort();
void onConnect(AcConnectHandler cb, void* arg = 0); //on successful connect
void onDisconnect(AcConnectHandler cb, void* arg = 0); //disconnected
void onAck(AcAckHandler cb, void* arg = 0); //ack received
void onError(AcErrorHandler cb, void* arg = 0); //unsuccessful connect or error
void onData(AcDataHandler cb, void* arg = 0); //data received (called if onPacket is not used)
void onPacket(AcPacketHandler cb, void* arg = 0); //data received
void onTimeout(AcTimeoutHandler cb, void* arg = 0); //ack timeout
void onPoll(AcConnectHandler cb, void* arg = 0); //every 125ms when connected
void ackPacket(struct pbuf * pb);//ack pbuf from onPacket
size_t ack(size_t len); //ack data that you have not acked using the method below
void ackLater(){ _ack_pcb = false; } //will not ack the current packet. Call from onData
const char * errorToString(int8_t error);
const char * stateToString();
//Do not use any of the functions below!
static int8_t _s_poll(void *arg, struct tcp_pcb *tpcb);
static int8_t _s_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *pb, int8_t err);
static int8_t _s_fin(void *arg, struct tcp_pcb *tpcb, int8_t err);
static int8_t _s_lwip_fin(void *arg, struct tcp_pcb *tpcb, int8_t err);
static void _s_error(void *arg, int8_t err);
static int8_t _s_sent(void *arg, struct tcp_pcb *tpcb, uint16_t len);
static int8_t _s_connected(void* arg, void* tpcb, int8_t err);
static void _s_dns_found(const char *name, struct ip_addr *ipaddr, void *arg);
int8_t _recv(tcp_pcb* pcb, pbuf* pb, int8_t err);
tcp_pcb * pcb(){ return _pcb; }
protected:
tcp_pcb* _pcb;
@@ -91,100 +177,13 @@ class AsyncClient {
int8_t _lwip_fin(tcp_pcb* pcb, int8_t err);
void _dns_found(struct ip_addr *ipaddr);
public:
AsyncClient* prev;
AsyncClient* next;
AsyncClient(tcp_pcb* pcb = 0);
~AsyncClient();
AsyncClient & operator=(const AsyncClient &other);
AsyncClient & operator+=(const AsyncClient &other);
bool operator==(const AsyncClient &other);
bool operator!=(const AsyncClient &other) {
return !(*this == other);
}
bool connect(IPAddress ip, uint16_t port);
bool connect(const char* host, uint16_t port);
void close(bool now = false);
void stop();
int8_t abort();
bool free();
bool canSend();//ack is not pending
size_t space();
size_t add(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY);//add for sending
bool send();//send all data added with the method above
size_t ack(size_t len); //ack data that you have not acked using the method below
void ackLater(){ _ack_pcb = false; } //will not ack the current packet. Call from onData
size_t write(const char* data);
size_t write(const char* data, size_t size, uint8_t apiflags=ASYNC_WRITE_FLAG_COPY); //only when canSend() == true
uint8_t state();
bool connecting();
bool connected();
bool disconnecting();
bool disconnected();
bool freeable();//disconnected or disconnecting
uint16_t getMss();
uint32_t getRxTimeout();
void setRxTimeout(uint32_t timeout);//no RX data timeout for the connection in seconds
uint32_t getAckTimeout();
void setAckTimeout(uint32_t timeout);//no ACK timeout for the last sent packet in milliseconds
void setNoDelay(bool nodelay);
bool getNoDelay();
uint32_t getRemoteAddress();
uint16_t getRemotePort();
uint32_t getLocalAddress();
uint16_t getLocalPort();
IPAddress remoteIP();
uint16_t remotePort();
IPAddress localIP();
uint16_t localPort();
void onConnect(AcConnectHandler cb, void* arg = 0); //on successful connect
void onDisconnect(AcConnectHandler cb, void* arg = 0); //disconnected
void onAck(AcAckHandler cb, void* arg = 0); //ack received
void onError(AcErrorHandler cb, void* arg = 0); //unsuccessful connect or error
void onData(AcDataHandler cb, void* arg = 0); //data received (called if onPacket is not used)
void onPacket(AcPacketHandler cb, void* arg = 0); //data received
void onTimeout(AcTimeoutHandler cb, void* arg = 0); //ack timeout
void onPoll(AcConnectHandler cb, void* arg = 0); //every 125ms when connected
void ackPacket(struct pbuf * pb);
const char * errorToString(int8_t error);
const char * stateToString();
int8_t _recv(tcp_pcb* pcb, pbuf* pb, int8_t err);
static int8_t _s_poll(void *arg, struct tcp_pcb *tpcb);
static int8_t _s_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *pb, int8_t err);
static int8_t _s_fin(void *arg, struct tcp_pcb *tpcb, int8_t err);
static int8_t _s_lwip_fin(void *arg, struct tcp_pcb *tpcb, int8_t err);
static void _s_error(void *arg, int8_t err);
static int8_t _s_sent(void *arg, struct tcp_pcb *tpcb, uint16_t len);
static int8_t _s_connected(void* arg, void* tpcb, int8_t err);
static void _s_dns_found(const char *name, struct ip_addr *ipaddr, void *arg);
};
class AsyncServer {
protected:
uint16_t _port;
IPAddress _addr;
bool _noDelay;
tcp_pcb* _pcb;
AcConnectHandler _connect_cb;
void* _connect_cb_arg;
public:
AsyncServer(IPAddress addr, uint16_t port);
AsyncServer(uint16_t port);
~AsyncServer();
@@ -195,9 +194,18 @@ class AsyncServer {
bool getNoDelay();
uint8_t status();
//Do not use any of the functions below!
static int8_t _s_accept(void *arg, tcp_pcb* newpcb, int8_t err);
static int8_t _s_accepted(void *arg, AsyncClient* client);
protected:
uint16_t _port;
IPAddress _addr;
bool _noDelay;
tcp_pcb* _pcb;
AcConnectHandler _connect_cb;
void* _connect_cb_arg;
int8_t _accept(tcp_pcb* newpcb, int8_t err);
int8_t _accepted(AsyncClient* client);
};