unplugged-vendor/system/bt/mediatek/utils/socket_server.cc

277 lines
8.1 KiB
C++

/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "socket_server.h"
#include <poll.h>
#include <cassert>
#include <cinttypes>
#include <cstdlib>
#include <map>
#include <mutex>
#include <thread>
#include <cutils/fs.h>
#include <cutils/sockets.h>
#include "chre_log.h"
namespace android {
namespace chre {
SocketServer::SocketServer()
: mSockFd(INVALID_SOCKET),
mNextClientId(1),
mPollFds{},
mClientMessageCallback(nullptr),
signal_received_{false} {
// Initialize the socket fds field for all inactive client slots to -1, so
// poll skips over it, and we don't attempt to send on it
for (size_t i = 0; i <= kMaxActiveClients; i++) {
mPollFds[i].fd = -1;
mPollFds[i].events = POLLIN;
}
}
void SocketServer::run(const char *socketName, bool allowSocketCreation,
ClientMessageCallback clientMessageCallback) {
mClientMessageCallback = clientMessageCallback;
mSockFd = android_get_control_socket(socketName);
if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
LOGI("Didn't inherit socket, creating...");
mSockFd = socket_local_server(socketName,
ANDROID_SOCKET_NAMESPACE_ABSTRACT,
SOCK_STREAM);
}
LOGI(" run: mSockFd[%d]",(int)mSockFd);
if (mSockFd == INVALID_SOCKET) {
LOGE("Couldn't get/create socket");
} else {
int ret = listen(mSockFd, kMaxPendingConnectionRequests);
if (ret < 0) {
LOG_ERROR("Couldn't listen on socket", errno);
} else {
serviceSocket();
}
{
std::lock_guard<std::mutex> lock(mClientsMutex);
for (const auto& pair : mClients) {
int clientSocket = pair.first;
if (close(clientSocket) != 0) {
LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
pair.second.clientId, strerror(errno));
}
}
mClients.clear();
LOGI("All clients are cleared");
}
socket_close(mSockFd);
mSockFd = INVALID_SOCKET;
std::atomic_exchange(&signal_received_, false);
}
}
void SocketServer::sendToAllClients(const void *data, size_t length) {
std::lock_guard<std::mutex> lock(mClientsMutex);
int deliveredCount = 0;
for (const auto& pair : mClients) {
int clientSocket = pair.first;
uint16_t clientId = pair.second.clientId;
if (sendToClientSocket(data, length, clientSocket, clientId)) {
deliveredCount++;
} else {
LOG_ERROR("send fail caused by: ", errno);
break;
}
}
if (deliveredCount == 0) {
LOGW("Got message but didn't deliver to any clients");
}
}
bool SocketServer::sendToClientById(const void *data, size_t length,
uint16_t clientId) {
std::lock_guard<std::mutex> lock(mClientsMutex);
bool sent = false;
for (const auto& pair : mClients) {
uint16_t thisClientId = pair.second.clientId;
if (thisClientId == clientId) {
int clientSocket = pair.first;
sent = sendToClientSocket(data, length, clientSocket, thisClientId);
break;
}
}
return sent;
}
void SocketServer::acceptClientConnection() {
struct sockaddr addr;
socklen_t addr_len = sizeof(addr);
int clientSocket = accept(mSockFd, &addr, &addr_len);
LOGI(" acceptClientConnection: clientSocket[%d] mSockFd[%d]",(int)clientSocket, (int)mSockFd);
if (clientSocket < 0) {
LOG_ERROR("Couldn't accept client connection", errno);
} else if (mClients.size() >= kMaxActiveClients) {
LOGW("Rejecting client request - maximum number of clients reached");
close(clientSocket);
} else {
ClientData clientData;
clientData.clientId = mNextClientId++;
// We currently don't handle wraparound - if we're getting this many
// connects/disconnects, then something is wrong.
// TODO: can handle this properly by iterating over the existing clients to
// avoid a conflict.
if (clientData.clientId == 0) {
LOGE("Couldn't allocate client ID");
std::exit(-1);
}
bool slotFound = false;
for (size_t i = kClientStartIndex; i <= kMaxActiveClients; i++) {
if (mPollFds[i].fd < 0) {
mPollFds[i].fd = clientSocket;
slotFound = true;
break;
}
}
if (!slotFound) {
LOGE("Couldn't find slot for client!");
assert(slotFound);
close(clientSocket);
} else {
{
std::lock_guard<std::mutex> lock(mClientsMutex);
mClients[clientSocket] = clientData;
}
LOGI("Accepted new client connection (count %zu), assigned client ID %"
PRIu16, mClients.size(), clientData.clientId);
}
}
}
void SocketServer::handleClientData(int clientSocket) {
const ClientData& clientData = mClients[clientSocket];
uint16_t clientId = clientData.clientId;
uint8_t buffer[kMaxPacketSize];
ssize_t packetSize = TEMP_FAILURE_RETRY(
recv(clientSocket, buffer, sizeof(buffer), MSG_DONTWAIT));
if (packetSize < 0) {
LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
strerror(errno));
} else if (packetSize == 0) {
LOGI("Client %" PRIu16 " disconnected", clientId);
disconnectClient(clientSocket);
} else {
LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
mClientMessageCallback(clientId, buffer, packetSize);
}
}
void SocketServer::disconnectClient(int clientSocket) {
LOGI(" disconnectClient: clientSocket[%d] ",(int)clientSocket);
{
std::lock_guard<std::mutex> lock(mClientsMutex);
mClients.erase(clientSocket);
}
close(clientSocket);
bool removed = false;
for (size_t i = kClientStartIndex; i <= kMaxActiveClients; i++) {
if (mPollFds[i].fd == clientSocket) {
mPollFds[i].fd = -1;
removed = true;
break;
}
}
if (!removed) {
LOGE("Out of sync");
assert(removed);
}
}
bool SocketServer::sendToClientSocket(const void *data, size_t length,
int clientSocket, uint16_t clientId) {
errno = 0;
ssize_t bytesSent = TEMP_FAILURE_RETRY(send(clientSocket, data, length, 0));
if (bytesSent < 0) {
LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s",
length, clientId, strerror(errno));
} else if (bytesSent == 0) {
LOGW("Client %" PRIu16 " disconnected before message could be delivered",
clientId);
} else {
LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
clientId);
}
return (bytesSent > 0);
}
void SocketServer::serviceSocket() {
static_assert(kListenIndex == 0, "Code assumes that the first index is "
"always the listen socket");
mPollFds[kListenIndex].fd = mSockFd;
mPollFds[kListenIndex].events = POLLIN;
LOGI("serviceSocket: Ready to accept connections: fd[%d] kListenIndex[%d]",(int)mPollFds[kListenIndex].fd, (int)kListenIndex);
while (!signal_received_) {
int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, nullptr);
if (ret == -1) {
LOGI("Exiting poll loop: %s", strerror(errno));
break;
}
if (mPollFds[kListenIndex].revents & POLLIN) {
acceptClientConnection();
}
for (size_t i = kClientStartIndex; i <= kMaxActiveClients; i++) {
if (mPollFds[i].fd < 0) {
continue;
}
if (mPollFds[i].revents & POLLIN) {
handleClientData(mPollFds[i].fd);
}
}
}
}
void SocketServer::StopPoll() {
{
std::lock_guard<std::mutex> lock(mClientsMutex);
LOGI(" StopPoll: fd[%d] kListenIndex[%d]",(int)mPollFds[kListenIndex].fd, (int)kListenIndex);
shutdown(mPollFds[kListenIndex].fd, SHUT_RDWR);
std::atomic_exchange(&signal_received_, true);
}
}
} // namespace chre
} // namespace android