blob: 3afb54a41cb5b8de1139b03db313643c24320d92 [file] [log] [blame]
/*
* Copyright (C) 2019 Igalia, S.L.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public License
* along with this library; see the file COPYING.LIB. If not, write to
* the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include "config.h"
#include "SocketConnection.h"
#include <cstring>
#include <gio/gio.h>
#include <wtf/ByteOrder.h>
#include <wtf/CheckedArithmetic.h>
#include <wtf/FastMalloc.h>
#include <wtf/RunLoop.h>
namespace WTF {
// Initial capacity of the read and write buffers, the step by which the read buffer grows
// when it runs out of room, and the capacity both buffers are trimmed back to once their
// pending contents drop below it.
static constexpr unsigned defaultBufferSize = 4096;
// Wraps |connection|, switches its socket to non-blocking mode and starts watching it for
// incoming data. Messages are dispatched to |messageHandlers| with |userData| as the last
// argument; the "DidClose" handler is invoked when the connection goes away.
SocketConnection::SocketConnection(GRefPtr<GSocketConnection>&& connection, const MessageHandlers& messageHandlers, gpointer userData)
    : m_connection(WTFMove(connection))
    , m_messageHandlers(messageHandlers)
    , m_userData(userData)
{
    // The read monitor below takes a Ref to this object before construction has finished.
    relaxAdoptionRequirement();

    m_readBuffer.reserveInitialCapacity(defaultBufferSize);
    m_writeBuffer.reserveInitialCapacity(defaultBufferSize);

    auto* socket = g_socket_connection_get_socket(m_connection.get());
    g_socket_set_blocking(socket, FALSE);
    m_readMonitor.start(socket, G_IO_IN, RunLoop::current(), [this, protectedThis = makeRef(*this)](GIOCondition condition) -> gboolean {
        if (isClosed())
            return G_SOURCE_REMOVE;

        // A hang-up or error can be reported together with G_IO_IN when the peer wrote its last
        // messages right before going away. Only close immediately when there is nothing left to
        // read; otherwise drain the socket first: read() calls didClose() itself once
        // g_socket_receive() reports end-of-stream or an error.
        if (!(condition & G_IO_IN) && (condition & (G_IO_HUP | G_IO_ERR | G_IO_NVAL))) {
            didClose();
            return G_SOURCE_REMOVE;
        }

        ASSERT(condition & G_IO_IN);
        return read();
    });
}
// Nothing is released by hand here: the socket monitors, the buffers and the connection
// reference are members and are torn down by their own destructors.
SocketConnection::~SocketConnection() = default;
bool SocketConnection::read()
{
while (true) {
size_t previousBufferSize = m_readBuffer.size();
if (m_readBuffer.capacity() - previousBufferSize <= 0)
m_readBuffer.reserveCapacity(m_readBuffer.capacity() + defaultBufferSize);
m_readBuffer.grow(m_readBuffer.capacity());
GUniqueOutPtr<GError> error;
auto bytesRead = g_socket_receive(g_socket_connection_get_socket(m_connection.get()), m_readBuffer.data() + previousBufferSize, m_readBuffer.size() - previousBufferSize, nullptr, &error.outPtr());
if (bytesRead == -1) {
if (g_error_matches(error.get(), G_IO_ERROR, G_IO_ERROR_WOULD_BLOCK)) {
m_readBuffer.shrink(previousBufferSize);
break;
}
g_warning("Error reading from socket connection: %s\n", error->message);
didClose();
return G_SOURCE_REMOVE;
}
if (!bytesRead) {
didClose();
return G_SOURCE_REMOVE;
}
m_readBuffer.shrink(previousBufferSize + bytesRead);
while (readMessage()) { }
if (isClosed())
return G_SOURCE_REMOVE;
}
return G_SOURCE_CONTINUE;
}
// Bits of the one-byte flags field that follows the body size in every message header.
enum {
    // Set by the sender when its native byte order is little endian.
    ByteOrderLittleEndian = 0x1
};
using MessageFlags = uint8_t;
// Returns true when the sender's byte order, as recorded in |flags|, differs from the byte
// order of this machine, i.e. when received parameters must be byte-swapped before use.
static inline bool messageIsByteSwapped(MessageFlags flags)
{
    const bool senderIsLittleEndian = flags & ByteOrderLittleEndian;
#if G_BYTE_ORDER == G_LITTLE_ENDIAN
    return !senderIsLittleEndian;
#else
    return senderIsLittleEndian;
#endif
}
bool SocketConnection::readMessage()
{
if (m_readBuffer.size() < sizeof(uint32_t))
return false;
auto* messageData = m_readBuffer.data();
uint32_t bodySizeHeader;
memcpy(&bodySizeHeader, messageData, sizeof(uint32_t));
messageData += sizeof(uint32_t);
bodySizeHeader = ntohl(bodySizeHeader);
Checked<size_t> bodySize = bodySizeHeader;
MessageFlags flags;
memcpy(&flags, messageData, sizeof(MessageFlags));
messageData += sizeof(MessageFlags);
auto messageSize = sizeof(uint32_t) + sizeof(MessageFlags) + bodySize;
if (m_readBuffer.size() < messageSize.unsafeGet())
return false;
Checked<size_t> messageNameLength = strlen(messageData);
messageNameLength++;
if (m_readBuffer.size() < messageNameLength.unsafeGet()) {
ASSERT_NOT_REACHED();
return false;
}
const auto it = m_messageHandlers.find(messageData);
if (it != m_messageHandlers.end()) {
messageData += messageNameLength.unsafeGet();
GRefPtr<GVariant> parameters;
if (!it->value.first.isNull()) {
GUniquePtr<GVariantType> variantType(g_variant_type_new(it->value.first.data()));
size_t parametersSize = bodySize.unsafeGet() - messageNameLength.unsafeGet();
// g_variant_new_from_data() requires the memory to be properly aligned for the type being loaded,
// but it's not possible to know the alignment because g_variant_type_info_query() is not public API.
// Since GLib 2.60 g_variant_new_from_data() already checks the alignment and reallocates the buffer
// in aligned memory only if needed. For older versions we can simply ensure the memory is 8 aligned.
#if GLIB_CHECK_VERSION(2, 60, 0)
parameters = g_variant_new_from_data(variantType.get(), messageData, parametersSize, FALSE, nullptr, nullptr);
#else
auto* alignedMemory = fastAlignedMalloc(8, parametersSize);
memcpy(alignedMemory, messageData, parametersSize);
GRefPtr<GBytes> bytes = g_bytes_new_with_free_func(alignedMemory, parametersSize, [](gpointer data) {
fastAlignedFree(data);
}, alignedMemory);
parameters = g_variant_new_from_bytes(variantType.get(), bytes.get(), FALSE);
#endif
if (messageIsByteSwapped(flags))
parameters = adoptGRef(g_variant_byteswap(parameters.get()));
}
it->value.second(*this, parameters.get(), m_userData);
if (isClosed())
return false;
}
if (m_readBuffer.size() > messageSize.unsafeGet()) {
std::memmove(m_readBuffer.data(), m_readBuffer.data() + messageSize.unsafeGet(), m_readBuffer.size() - messageSize.unsafeGet());
m_readBuffer.shrink(m_readBuffer.size() - messageSize.unsafeGet());
} else
m_readBuffer.shrink(0);
if (m_readBuffer.size() < defaultBufferSize)
m_readBuffer.shrinkCapacity(defaultBufferSize);
return true;
}
void SocketConnection::sendMessage(const char* messageName, GVariant* parameters)
{
GRefPtr<GVariant> adoptedParameters = parameters;
size_t parametersSize = parameters ? g_variant_get_size(parameters) : 0;
CheckedSize messageNameLength = strlen(messageName);
messageNameLength++;
if (UNLIKELY(messageNameLength.hasOverflowed())) {
g_warning("Trying to send message with invalid too long name");
return;
}
Checked<uint32_t, RecordOverflow> bodySize = messageNameLength + parametersSize;
if (UNLIKELY(bodySize.hasOverflowed())) {
g_warning("Trying to send message '%s' with invalid too long body", messageName);
return;
}
size_t previousBufferSize = m_writeBuffer.size();
m_writeBuffer.grow(previousBufferSize + sizeof(uint32_t) + sizeof(MessageFlags) + bodySize.unsafeGet());
auto* messageData = m_writeBuffer.data() + previousBufferSize;
uint32_t bodySizeHeader = htonl(bodySize.unsafeGet());
memcpy(messageData, &bodySizeHeader, sizeof(uint32_t));
messageData += sizeof(uint32_t);
MessageFlags flags = 0;
#if G_BYTE_ORDER == G_LITTLE_ENDIAN
flags |= ByteOrderLittleEndian;
#endif
memcpy(messageData, &flags, sizeof(MessageFlags));
messageData += sizeof(MessageFlags);
memcpy(messageData, messageName, messageNameLength.unsafeGet());
messageData += messageNameLength.unsafeGet();
if (parameters)
memcpy(messageData, g_variant_get_data(parameters), parametersSize);
write();
}
void SocketConnection::write()
{
if (isClosed())
return;
GUniqueOutPtr<GError> error;
auto bytesWritten = g_socket_send(g_socket_connection_get_socket(m_connection.get()), m_writeBuffer.data(), m_writeBuffer.size(), nullptr, &error.outPtr());
if (bytesWritten == -1) {
if (g_error_matches(error.get(), G_IO_ERROR, G_IO_ERROR_WOULD_BLOCK)) {
waitForSocketWritability();
return;
}
g_warning("Error sending message on socket connection: %s\n", error->message);
didClose();
return;
}
if (m_writeBuffer.size() > static_cast<size_t>(bytesWritten)) {
std::memmove(m_writeBuffer.data(), m_writeBuffer.data() + bytesWritten, m_writeBuffer.size() - bytesWritten);
m_writeBuffer.shrink(m_writeBuffer.size() - bytesWritten);
} else
m_writeBuffer.shrink(0);
if (m_writeBuffer.size() < defaultBufferSize)
m_writeBuffer.shrinkCapacity(defaultBufferSize);
if (!m_writeBuffer.isEmpty())
waitForSocketWritability();
}
// Arms m_writeMonitor so that write() is retried once the socket can accept more data.
// Does nothing while the monitor is already armed.
void SocketConnection::waitForSocketWritability()
{
    if (m_writeMonitor.isActive())
        return;

    m_writeMonitor.start(g_socket_connection_get_socket(m_connection.get()), G_IO_OUT, RunLoop::current(), [this, protectedThis = makeRef(*this)] (GIOCondition) -> gboolean {
        // The source is removed below whatever condition woke us up, so always stop the monitor
        // and retry the write. Socket sources also report G_IO_HUP/G_IO_ERR without G_IO_OUT.
        // Returning G_SOURCE_REMOVE in that case without stopping the monitor would presumably
        // leave isActive() true and prevent any later re-arming. write() then flushes, re-arms
        // the monitor, or reports the socket error and closes the connection.
        // We can't stop the monitor from this lambda, because stop destroys the lambda.
        RunLoop::current().dispatch([this, protectedThis] {
            m_writeMonitor.stop();
            write();
        });
        return G_SOURCE_REMOVE;
    });
}
// Stops watching the socket for readability and writability and drops the reference to the
// connection. Unlike didClose(), this does not invoke the "DidClose" handler.
// NOTE(review): isClosed() is presumably derived from m_connection being null — confirm in
// SocketConnection.h.
void SocketConnection::close()
{
// NOTE(review): both monitor callbacks capture a Ref to this object (protectedThis), and
// stopping a monitor may release its callback. Callers are assumed to keep their own
// reference so that |this| is still alive when m_connection is reset below.
m_readMonitor.stop();
m_writeMonitor.stop();
m_connection = nullptr;
}
// Closes the connection, if it is still open, and notifies the "DidClose" handler exactly
// once. Subsequent calls are no-ops.
void SocketConnection::didClose()
{
    if (isClosed())
        return;

    close();

    // Look the handler up once, as readMessage() does. A missing entry is a programming error.
    // Guard it anyway, because get() would hand back a default-constructed (empty) callback
    // to call in release builds.
    const auto it = m_messageHandlers.find("DidClose");
    ASSERT(it != m_messageHandlers.end());
    if (it != m_messageHandlers.end())
        it->value.second(*this, nullptr, m_userData);
}
} // namespace WTF