diff --git a/src/internal.c b/src/internal.c index c48d11b02..b8e88d114 100644 --- a/src/internal.c +++ b/src/internal.c @@ -584,6 +584,40 @@ static void HandshakeInfoFree(HandshakeInfo* hs, void* heap) } +/* RFC 4253 section 7.1, Once having sent SSH_MSG_KEXINIT the only messages +* that can be sent are 1-19 (except SSH_MSG_SERVICE_REQUEST and +* SSH_MSG_SERVICE_ACCEPT), 20-29 (except SSH_MSG_KEXINIT again), and 30-49 +*/ +INLINE static int IsMessageAllowedKeying(WOLFSSH *ssh, byte msg) +{ + if (ssh->isKeying == 0) { + return 1; + } + + /* case of servie request or accept in 1-19 */ + if (msg == MSGID_SERVICE_REQUEST || msg == MSGID_SERVICE_ACCEPT) { + WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by during rekeying", msg); + ssh->error = WS_REKEYING; + return 0; + } + + /* case of resending SSH_MSG_KEXINIT */ + if (msg == MSGID_KEXINIT) { + WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by during rekeying", msg); + ssh->error = WS_REKEYING; + return 0; + } + + /* case where message id greater than 49 */ + if (msg >= MSGID_USERAUTH_REQUEST) { + WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by during rekeying", msg); + ssh->error = WS_REKEYING; + return 0; + } + return 1; +} + + #ifndef NO_WOLFSSH_SERVER INLINE static int IsMessageAllowedServer(WOLFSSH *ssh, byte msg) { @@ -662,8 +696,12 @@ INLINE static int IsMessageAllowedClient(WOLFSSH *ssh, byte msg) #endif /* NO_WOLFSSH_CLIENT */ -INLINE static int IsMessageAllowed(WOLFSSH *ssh, byte msg) +INLINE static int IsMessageAllowed(WOLFSSH *ssh, byte msg, byte state) { + if (state == WS_MSG_SEND && !IsMessageAllowedKeying(ssh, msg)) { + return 0; + } + #ifndef NO_WOLFSSH_SERVER if (ssh->ctx->side == WOLFSSH_ENDPOINT_SERVER) { return IsMessageAllowedServer(ssh, msg); @@ -5808,7 +5846,6 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) HandshakeInfoFree(ssh->handshake, ssh->ctx->heap); ssh->handshake = NULL; WLOG(WS_LOG_DEBUG, "Keying completed"); - if (ssh->ctx->keyingCompletionCb) ssh->ctx->keyingCompletionCb(ssh->keyingCompletionCtx); } @@ -9178,7 +9215,7 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed) return WS_OVERFLOW_E; } - if (!IsMessageAllowed(ssh, msg)) { + if (!IsMessageAllowed(ssh, msg, WS_MSG_RECV)) { return WS_MSGID_NOT_ALLOWED_E; } @@ -15425,6 +15462,12 @@ int SendChannelEof(WOLFSSH* ssh, word32 peerChannelId) if (ssh == NULL) ret = WS_BAD_ARGUMENT; + if (ret == WS_SUCCESS) { + if (!IsMessageAllowed(ssh, MSGID_CHANNEL_EOF, WS_MSG_SEND)) { + ret = WS_MSGID_NOT_ALLOWED_E; + } + } + if (ret == WS_SUCCESS) { channel = ChannelFind(ssh, peerChannelId, WS_CHANNEL_ID_PEER); if (channel == NULL) @@ -15853,6 +15896,12 @@ int SendChannelWindowAdjust(WOLFSSH* ssh, word32 channelId, if (ssh == NULL) ret = WS_BAD_ARGUMENT; + if (ret == WS_SUCCESS) { + if (!IsMessageAllowed(ssh, MSGID_CHANNEL_WINDOW_ADJUST, WS_MSG_SEND)) { + ret = WS_MSGID_NOT_ALLOWED_E; + } + } + channel = ChannelFind(ssh, channelId, WS_CHANNEL_ID_SELF); if (channel == NULL) { WLOG(WS_LOG_DEBUG, "Invalid channel"); diff --git a/wolfssh/internal.h b/wolfssh/internal.h index 8aaf65cce..cb4a357c1 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -1205,6 +1205,10 @@ enum WS_MessageIds { #define CHANNEL_EXTENDED_DATA_STDERR WOLFSSH_EXT_DATA_STDERR +/* Used when checking IsMessageAllowed() to determine if createing and sending + * the message or receiving the message is allowed */ +#define WS_MSG_SEND 1 +#define WS_MSG_RECV 2 /* dynamic memory types */ enum WS_DynamicTypes { @@ -1398,4 +1402,3 @@ enum TerminalModes { #endif #endif /* _WOLFSSH_INTERNAL_H_ */ -