diff --git a/components/network/websocket/luat_lib_websocket.c b/components/network/websocket/luat_lib_websocket.c index a5918ae99..445190f62 100644 --- a/components/network/websocket/luat_lib_websocket.c +++ b/components/network/websocket/luat_lib_websocket.c @@ -57,6 +57,7 @@ static luat_websocket_ctrl_t *get_websocket_ctrl(lua_State *L) static int32_t l_websocket_callback(lua_State *L, void *ptr) { + (void)ptr; rtos_msg_t *msg = (rtos_msg_t *)lua_topointer(L, -1); luat_websocket_ctrl_t *websocket_ctrl = (luat_websocket_ctrl_t *)msg->ptr; luat_websocket_pkg_t pkg = {0}; @@ -198,7 +199,7 @@ static int l_websocket_create(lua_State *L) luat_websocket_connopts_t opts = {0}; // 连接参数相关 - const char *ip; + // const char *ip; size_t ip_len = 0; #ifdef LUAT_USE_LWIP websocket_ctrl->ip_addr.type = 0xff; @@ -295,13 +296,13 @@ static int l_websocket_autoreconn(lua_State *L) @string 待发送的数据,必填 @int 是否为最后一帧,默认1 @int 操作码, 默认为字符串帧 -@return int 消息id, 当qos为1或2时会有效值. 若底层返回是否, 会返回nil +@return bool 成功返回true,否则为false或者nil @usage wsc:publish("/luatos/123456", "123") */ static int l_websocket_send(lua_State *L) { - uint32_t payload_len = 0; + size_t payload_len = 0; luat_websocket_ctrl_t *websocket_ctrl = get_websocket_ctrl(L); const char *payload = NULL; luat_zbuff_t *buff = NULL; @@ -327,7 +328,8 @@ static int l_websocket_send(lua_State *L) .plen = payload_len, .payload = payload}; ret = luat_websocket_send_frame(websocket_ctrl, &pkg); - return 0; + lua_pushboolean(L, ret == 0 ? 1 : 0); + return 1; } /* @@ -355,7 +357,7 @@ websocket客户端是否就绪 @api wsc:ready() @return bool 客户端是否就绪 @usage -local error = wsc:ready() +local stat = wsc:ready() */ static int l_websocket_ready(lua_State *L) { @@ -364,6 +366,82 @@ static int l_websocket_ready(lua_State *L) return 1; } +/* +设置额外的headers +@api wsc:headers(headers) +@table/string 可以是table,也可以是字符串 +@return bool 客户端是否就绪 +@usage +-- table形式 +wsc:headers({ + Auth="Basic ABCDEFGG" +}) +-- 字符串形式 +wsc:headers("Auth: Basic ABCDERG\r\n") +*/ +static int l_websocket_headers(lua_State *L) +{ + luat_websocket_ctrl_t *websocket_ctrl = get_websocket_ctrl(L); + if (!lua_istable(L, 2) && !lua_isstring(L, 2)) { + return 0; + } + #define WS_HEADER_MAX (1024) + char* buff = luat_heap_malloc(WS_HEADER_MAX); + memset(buff, 0, WS_HEADER_MAX); + if (lua_istable(L, 2)) { + size_t name_sz = 0; + size_t value_sz = 0; + lua_pushnil(L); + while (lua_next(L, 2) != 0) { + const char *name = lua_tolstring(L, -2, &name_sz); + const char *value = lua_tolstring(L, -1, &value_sz); + if (name_sz == 0 || value_sz == 0 || name_sz + value_sz > 256) { + LLOGW("bad header %s %s", name, value); + luat_heap_free(buff); + return 0; + } + memcpy(buff + strlen(buff), name, name_sz); + memcpy(buff + strlen(buff), ":", 1); + if (WS_HEADER_MAX - strlen(buff) < value_sz * 2) { + LLOGW("bad header %s %s, too large", name, value); + luat_heap_free(buff); + return 0; + } + for (size_t i = 0; i < value_sz; i++) + { + switch (value[i]) + { + case '*': + case '-': + case '.': + case '_': + case ' ': + sprintf_(buff + strlen(buff), "%%%02X", value[i]); + break; + default: + buff[strlen(buff)] = value[i]; + break; + } + } + lua_pop(L, 1); + memcpy(buff + strlen(buff), "\r\n", 2); + } + } + else { + size_t len = 0; + const char* data = luaL_checklstring(L, 2, &len); + if (len > 1023) { + LLOGW("headers too large size %d", len); + luat_heap_free(buff); + return 0; + } + memcpy(buff, data, len); + } + luat_websocket_set_headers(websocket_ctrl, buff); + lua_pushboolean(L, 1); + return 1; +} + static int _websocket_struct_newindex(lua_State *L); void luat_websocket_struct_init(lua_State *L) @@ -377,19 +455,22 @@ void luat_websocket_struct_init(lua_State *L) #include "rotable2.h" const rotable_Reg_t reg_websocket[] = { - {"create", ROREG_FUNC(l_websocket_create)}, - {"on", ROREG_FUNC(l_websocket_on)}, - {"connect", ROREG_FUNC(l_websocket_connect)}, - {"autoreconn", ROREG_FUNC(l_websocket_autoreconn)}, - {"send", ROREG_FUNC(l_websocket_send)}, - {"close", ROREG_FUNC(l_websocket_close)}, - {"ready", ROREG_FUNC(l_websocket_ready)}, + {"create", ROREG_FUNC(l_websocket_create)}, + {"on", ROREG_FUNC(l_websocket_on)}, + {"connect", ROREG_FUNC(l_websocket_connect)}, + {"autoreconn", ROREG_FUNC(l_websocket_autoreconn)}, + {"send", ROREG_FUNC(l_websocket_send)}, + {"close", ROREG_FUNC(l_websocket_close)}, + {"ready", ROREG_FUNC(l_websocket_ready)}, + {"headers", ROREG_FUNC(l_websocket_headers)}, + {"debug", ROREG_FUNC(l_websocket_set_debug)}, - {NULL, ROREG_INT(0)}}; + {NULL, ROREG_INT(0)} +}; int _websocket_struct_newindex(lua_State *L) { - rotable_Reg_t *reg = reg_websocket; + const rotable_Reg_t *reg = reg_websocket; const char *key = luaL_checkstring(L, 2); while (1) { diff --git a/components/network/websocket/luat_websocket.c b/components/network/websocket/luat_websocket.c index 2c162ad36..2f25345cf 100644 --- a/components/network/websocket/luat_websocket.c +++ b/components/network/websocket/luat_websocket.c @@ -281,6 +281,10 @@ void luat_websocket_release_socket(luat_websocket_ctrl_t *websocket_ctrl) luat_release_rtos_timer(websocket_ctrl->reconnect_timer); websocket_ctrl->reconnect_timer = NULL; } + if (websocket_ctrl->headers) { + luat_heap_free(websocket_ctrl->headers); + websocket_ctrl->headers = NULL; + } if (websocket_ctrl->netc) { network_release_ctrl(websocket_ctrl->netc); @@ -288,6 +292,13 @@ void luat_websocket_release_socket(luat_websocket_ctrl_t *websocket_ctrl) } } +static const char* ws_headers = + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + static int websocket_connect(luat_websocket_ctrl_t *websocket_ctrl) { LLOGD("request host %s port %d uri %s", websocket_ctrl->host, websocket_ctrl->remote_port, websocket_ctrl->uri); @@ -295,15 +306,14 @@ static int websocket_connect(luat_websocket_ctrl_t *websocket_ctrl) int ret = snprintf_((char*)websocket_ctrl->pkg_buff, WEBSOCKET_RECV_BUF_LEN_MAX, "GET %s HTTP/1.1\r\n" - "Host: %s\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n" - "Sec-WebSocket-Version: 13\r\n" - "\r\n", + "Host: %s\r\n", websocket_ctrl->uri, websocket_ctrl->host); - LLOGD("Request %s", websocket_ctrl->pkg_buff); + //LLOGD("Request %s", websocket_ctrl->pkg_buff); ret = luat_websocket_send_packet(websocket_ctrl, websocket_ctrl->pkg_buff, ret); + if (websocket_ctrl->headers) { + luat_websocket_send_packet(websocket_ctrl, websocket_ctrl->headers, strlen(websocket_ctrl->headers)); + } + luat_websocket_send_packet(websocket_ctrl, ws_headers, strlen(ws_headers)); LLOGD("websocket_connect ret %d", ret); return ret; } @@ -441,7 +451,7 @@ static int websocket_parse(luat_websocket_ctrl_t *websocket_ctrl) return -1; } memcpy(buff, buf, pkg_len); - l_luat_websocket_msg_cb(websocket_ctrl, WEBSOCKET_MSG_PUBLISH, buff); + l_luat_websocket_msg_cb(websocket_ctrl, WEBSOCKET_MSG_PUBLISH, (int)buff); } // 处理完成后, 如果还有数据, 移动数据, 继续处理 @@ -457,8 +467,8 @@ static int websocket_parse(luat_websocket_ctrl_t *websocket_ctrl) int luat_websocket_read_packet(luat_websocket_ctrl_t *websocket_ctrl) { // LLOGD("luat_websocket_read_packet websocket_ctrl->buffer_offset:%d",websocket_ctrl->buffer_offset); - int ret = -1; - uint8_t *read_buff = NULL; + // int ret = -1; + // uint8_t *read_buff = NULL; uint32_t total_len = 0; uint32_t rx_len = 0; int result = network_rx(websocket_ctrl->netc, NULL, 0, 0, NULL, NULL, &total_len); @@ -628,3 +638,16 @@ int luat_websocket_connect(luat_websocket_ctrl_t *websocket_ctrl) } return 0; } + +int luat_websocket_set_headers(luat_websocket_ctrl_t *websocket_ctrl, const char *headers) { + if (websocket_ctrl == NULL) + return 0; + if (websocket_ctrl->headers != NULL) { + luat_heap_free(websocket_ctrl->headers); + websocket_ctrl->headers = NULL; + } + if (headers) { + websocket_ctrl->headers = headers; + } + return 0; +} diff --git a/components/network/websocket/luat_websocket.h b/components/network/websocket/luat_websocket.h index 73aee89ad..a594a6c5c 100644 --- a/components/network/websocket/luat_websocket.h +++ b/components/network/websocket/luat_websocket.h @@ -32,6 +32,7 @@ typedef struct void *reconnect_timer; // websocket重连定时器 void *ping_timer; // websocket_ping定时器 int websocket_ref; // 强制引用自身避免被GC + char* headers; } luat_websocket_ctrl_t; typedef struct luat_websocket_connopts @@ -69,5 +70,5 @@ int luat_websocket_init(luat_websocket_ctrl_t *websocket_ctrl, int adapter_index int luat_websocket_set_connopts(luat_websocket_ctrl_t *websocket_ctrl, const char *url); int luat_websocket_payload(char *buff, luat_websocket_pkg_t *pkg, size_t limit); int luat_websocket_send_frame(luat_websocket_ctrl_t *websocket_ctrl, luat_websocket_pkg_t *pkg); - +int luat_websocket_set_headers(luat_websocket_ctrl_t *websocket_ctrl, const char *headers); #endif diff --git a/demo/websocket/main.lua b/demo/websocket/main.lua index 6e2375c7b..1318be87c 100644 --- a/demo/websocket/main.lua +++ b/demo/websocket/main.lua @@ -19,8 +19,8 @@ local wsc = nil sys.taskInit(function() if rtos.bsp():startsWith("ESP32") then - local ssid = "uiot123" - local password = "12348888" + local ssid = "uiot" + local password = "1234567890" log.info("wifi", ssid, password) -- TODO 改成esptouch配网 LED = gpio.setup(12, 0, gpio.PULLUP) @@ -46,6 +46,9 @@ sys.taskInit(function() -- 这是个测试服务, 当发送的是json,且action=echo,就会回显所发送的内容 wsc = websocket.create(nil, "ws://echo.airtun.air32.cn/ws/echo") + if wsc.headers then + wsc:headers({Auth="Basic ABCDEGG"}) + end wsc:autoreconn(true, 3000) -- 自动重连机制 wsc:on(function(wsc, event, data, fin, optcode) -- event 事件, 当前有conack和recv