423 lines
11 KiB
C++
423 lines
11 KiB
C++
|
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style license that can be
|
||
|
// found in the LICENSE file.
|
||
|
|
||
|
#include "tools/android/forwarder2/socket.h"
|
||
|
|
||
|
#include <arpa/inet.h>
|
||
|
#include <fcntl.h>
|
||
|
#include <netdb.h>
|
||
|
#include <netinet/in.h>
|
||
|
#include <stdio.h>
|
||
|
#include <string.h>
|
||
|
#include <sys/socket.h>
|
||
|
#include <sys/types.h>
|
||
|
#include <unistd.h>
|
||
|
|
||
|
#include "base/logging.h"
|
||
|
#include "base/posix/eintr_wrapper.h"
|
||
|
#include "base/safe_strerror_posix.h"
|
||
|
#include "tools/android/common/net.h"
|
||
|
#include "tools/android/forwarder2/common.h"
|
||
|
|
||
|
namespace {
|
||
|
const int kNoTimeout = -1;
|
||
|
const int kConnectTimeOut = 10; // Seconds.
|
||
|
|
||
|
bool FamilyIsTCP(int family) {
|
||
|
return family == AF_INET || family == AF_INET6;
|
||
|
}
|
||
|
} // namespace
|
||
|
|
||
|
namespace forwarder2 {
|
||
|
|
||
|
bool Socket::BindUnix(const std::string& path) {
|
||
|
errno = 0;
|
||
|
if (!InitUnixSocket(path) || !BindAndListen()) {
|
||
|
Close();
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Socket::BindTcp(const std::string& host, int port) {
|
||
|
errno = 0;
|
||
|
if (!InitTcpSocket(host, port) || !BindAndListen()) {
|
||
|
Close();
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Socket::ConnectUnix(const std::string& path) {
|
||
|
errno = 0;
|
||
|
if (!InitUnixSocket(path) || !Connect()) {
|
||
|
Close();
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Socket::ConnectTcp(const std::string& host, int port) {
|
||
|
errno = 0;
|
||
|
if (!InitTcpSocket(host, port) || !Connect()) {
|
||
|
Close();
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
Socket::Socket()
|
||
|
: socket_(-1),
|
||
|
port_(0),
|
||
|
socket_error_(false),
|
||
|
family_(AF_INET),
|
||
|
addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
|
||
|
addr_len_(sizeof(sockaddr)) {
|
||
|
memset(&addr_, 0, sizeof(addr_));
|
||
|
}
|
||
|
|
||
|
Socket::~Socket() {
|
||
|
Close();
|
||
|
}
|
||
|
|
||
|
void Socket::Shutdown() {
|
||
|
if (!IsClosed()) {
|
||
|
PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void Socket::Close() {
|
||
|
if (!IsClosed()) {
|
||
|
CloseFD(socket_);
|
||
|
socket_ = -1;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
bool Socket::InitSocketInternal() {
|
||
|
socket_ = socket(family_, SOCK_STREAM, 0);
|
||
|
if (socket_ < 0)
|
||
|
return false;
|
||
|
tools::DisableNagle(socket_);
|
||
|
int reuse_addr = 1;
|
||
|
setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
|
||
|
&reuse_addr, sizeof(reuse_addr));
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Socket::InitUnixSocket(const std::string& path) {
|
||
|
static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
|
||
|
// For abstract sockets we need one extra byte for the leading zero.
|
||
|
if (path.size() + 2 /* '\0' */ > kPathMax) {
|
||
|
LOG(ERROR) << "The provided path is too big to create a unix "
|
||
|
<< "domain socket: " << path;
|
||
|
return false;
|
||
|
}
|
||
|
family_ = PF_UNIX;
|
||
|
addr_.addr_un.sun_family = family_;
|
||
|
// Copied from net/socket/unix_domain_socket_posix.cc
|
||
|
// Convert the path given into abstract socket name. It must start with
|
||
|
// the '\0' character, so we are adding it. |addr_len| must specify the
|
||
|
// length of the structure exactly, as potentially the socket name may
|
||
|
// have '\0' characters embedded (although we don't support this).
|
||
|
// Note that addr_.addr_un.sun_path is already zero initialized.
|
||
|
memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
|
||
|
addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
|
||
|
addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
|
||
|
return InitSocketInternal();
|
||
|
}
|
||
|
|
||
|
bool Socket::InitTcpSocket(const std::string& host, int port) {
|
||
|
port_ = port;
|
||
|
if (host.empty()) {
|
||
|
// Use localhost: INADDR_LOOPBACK
|
||
|
family_ = AF_INET;
|
||
|
addr_.addr4.sin_family = family_;
|
||
|
addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
|
||
|
} else if (!Resolve(host)) {
|
||
|
return false;
|
||
|
}
|
||
|
CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
|
||
|
if (family_ == AF_INET) {
|
||
|
addr_.addr4.sin_port = htons(port_);
|
||
|
addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
|
||
|
addr_len_ = sizeof(addr_.addr4);
|
||
|
} else if (family_ == AF_INET6) {
|
||
|
addr_.addr6.sin6_port = htons(port_);
|
||
|
addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
|
||
|
addr_len_ = sizeof(addr_.addr6);
|
||
|
}
|
||
|
return InitSocketInternal();
|
||
|
}
|
||
|
|
||
|
bool Socket::BindAndListen() {
|
||
|
errno = 0;
|
||
|
if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
|
||
|
HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
|
||
|
SetSocketError();
|
||
|
return false;
|
||
|
}
|
||
|
if (port_ == 0 && FamilyIsTCP(family_)) {
|
||
|
SockAddr addr;
|
||
|
memset(&addr, 0, sizeof(addr));
|
||
|
socklen_t addrlen = 0;
|
||
|
sockaddr* addr_ptr = NULL;
|
||
|
uint16* port_ptr = NULL;
|
||
|
if (family_ == AF_INET) {
|
||
|
addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
|
||
|
port_ptr = &addr.addr4.sin_port;
|
||
|
addrlen = sizeof(addr.addr4);
|
||
|
} else if (family_ == AF_INET6) {
|
||
|
addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
|
||
|
port_ptr = &addr.addr6.sin6_port;
|
||
|
addrlen = sizeof(addr.addr6);
|
||
|
}
|
||
|
errno = 0;
|
||
|
if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
|
||
|
LOG(ERROR) << "getsockname error: " << safe_strerror(errno);;
|
||
|
SetSocketError();
|
||
|
return false;
|
||
|
}
|
||
|
port_ = ntohs(*port_ptr);
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Socket::Accept(Socket* new_socket) {
|
||
|
DCHECK(new_socket != NULL);
|
||
|
if (!WaitForEvent(READ, kNoTimeout)) {
|
||
|
SetSocketError();
|
||
|
return false;
|
||
|
}
|
||
|
errno = 0;
|
||
|
int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
|
||
|
if (new_socket_fd < 0) {
|
||
|
SetSocketError();
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
tools::DisableNagle(new_socket_fd);
|
||
|
new_socket->socket_ = new_socket_fd;
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Socket::Connect() {
|
||
|
// Set non-block because we use select for connect.
|
||
|
const int kFlags = fcntl(socket_, F_GETFL);
|
||
|
DCHECK(!(kFlags & O_NONBLOCK));
|
||
|
fcntl(socket_, F_SETFL, kFlags | O_NONBLOCK);
|
||
|
errno = 0;
|
||
|
if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
|
||
|
errno != EINPROGRESS) {
|
||
|
SetSocketError();
|
||
|
PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
|
||
|
return false;
|
||
|
}
|
||
|
// Wait for connection to complete, or receive a notification.
|
||
|
if (!WaitForEvent(WRITE, kConnectTimeOut)) {
|
||
|
SetSocketError();
|
||
|
PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
|
||
|
return false;
|
||
|
}
|
||
|
int socket_errno;
|
||
|
socklen_t opt_len = sizeof(socket_errno);
|
||
|
if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
|
||
|
LOG(ERROR) << "getsockopt(): " << safe_strerror(errno);
|
||
|
SetSocketError();
|
||
|
PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
|
||
|
return false;
|
||
|
}
|
||
|
if (socket_errno != 0) {
|
||
|
LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
|
||
|
SetSocketError();
|
||
|
PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
|
||
|
return false;
|
||
|
}
|
||
|
fcntl(socket_, F_SETFL, kFlags);
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Socket::Resolve(const std::string& host) {
|
||
|
struct addrinfo hints;
|
||
|
struct addrinfo* res;
|
||
|
memset(&hints, 0, sizeof(hints));
|
||
|
hints.ai_family = AF_UNSPEC;
|
||
|
hints.ai_socktype = SOCK_STREAM;
|
||
|
hints.ai_flags |= AI_CANONNAME;
|
||
|
|
||
|
int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
|
||
|
if (errcode != 0) {
|
||
|
SetSocketError();
|
||
|
freeaddrinfo(res);
|
||
|
return false;
|
||
|
}
|
||
|
family_ = res->ai_family;
|
||
|
switch (res->ai_family) {
|
||
|
case AF_INET:
|
||
|
memcpy(&addr_.addr4,
|
||
|
reinterpret_cast<sockaddr_in*>(res->ai_addr),
|
||
|
sizeof(sockaddr_in));
|
||
|
break;
|
||
|
case AF_INET6:
|
||
|
memcpy(&addr_.addr6,
|
||
|
reinterpret_cast<sockaddr_in6*>(res->ai_addr),
|
||
|
sizeof(sockaddr_in6));
|
||
|
break;
|
||
|
}
|
||
|
freeaddrinfo(res);
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
int Socket::GetPort() {
|
||
|
if (!FamilyIsTCP(family_)) {
|
||
|
LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
|
||
|
return 0;
|
||
|
}
|
||
|
return port_;
|
||
|
}
|
||
|
|
||
|
bool Socket::IsFdInSet(const fd_set& fds) const {
|
||
|
if (IsClosed())
|
||
|
return false;
|
||
|
return FD_ISSET(socket_, &fds);
|
||
|
}
|
||
|
|
||
|
bool Socket::AddFdToSet(fd_set* fds) const {
|
||
|
if (IsClosed())
|
||
|
return false;
|
||
|
FD_SET(socket_, fds);
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
|
||
|
int bytes_read = 0;
|
||
|
int ret = 1;
|
||
|
while (bytes_read < num_bytes && ret > 0) {
|
||
|
ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
|
||
|
if (ret >= 0)
|
||
|
bytes_read += ret;
|
||
|
}
|
||
|
return bytes_read;
|
||
|
}
|
||
|
|
||
|
void Socket::SetSocketError() {
|
||
|
socket_error_ = true;
|
||
|
// We never use non-blocking socket.
|
||
|
DCHECK(errno != EAGAIN && errno != EWOULDBLOCK);
|
||
|
Close();
|
||
|
}
|
||
|
|
||
|
int Socket::Read(void* buffer, size_t buffer_size) {
|
||
|
if (!WaitForEvent(READ, kNoTimeout)) {
|
||
|
SetSocketError();
|
||
|
return 0;
|
||
|
}
|
||
|
int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
|
||
|
if (ret < 0)
|
||
|
SetSocketError();
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
int Socket::Write(const void* buffer, size_t count) {
|
||
|
int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
|
||
|
if (ret < 0)
|
||
|
SetSocketError();
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
int Socket::WriteString(const std::string& buffer) {
|
||
|
return WriteNumBytes(buffer.c_str(), buffer.size());
|
||
|
}
|
||
|
|
||
|
void Socket::AddEventFd(int event_fd) {
|
||
|
Event event;
|
||
|
event.fd = event_fd;
|
||
|
event.was_fired = false;
|
||
|
events_.push_back(event);
|
||
|
}
|
||
|
|
||
|
bool Socket::DidReceiveEventOnFd(int fd) const {
|
||
|
for (size_t i = 0; i < events_.size(); ++i)
|
||
|
if (events_[i].fd == fd)
|
||
|
return events_[i].was_fired;
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
bool Socket::DidReceiveEvent() const {
|
||
|
for (size_t i = 0; i < events_.size(); ++i)
|
||
|
if (events_[i].was_fired)
|
||
|
return true;
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
|
||
|
int bytes_written = 0;
|
||
|
int ret = 1;
|
||
|
while (bytes_written < num_bytes && ret > 0) {
|
||
|
ret = Write(static_cast<const char*>(buffer) + bytes_written,
|
||
|
num_bytes - bytes_written);
|
||
|
if (ret >= 0)
|
||
|
bytes_written += ret;
|
||
|
}
|
||
|
return bytes_written;
|
||
|
}
|
||
|
|
||
|
bool Socket::WaitForEvent(EventType type, int timeout_secs) {
|
||
|
if (events_.empty() || socket_ == -1)
|
||
|
return true;
|
||
|
fd_set read_fds;
|
||
|
fd_set write_fds;
|
||
|
FD_ZERO(&read_fds);
|
||
|
FD_ZERO(&write_fds);
|
||
|
if (type == READ)
|
||
|
FD_SET(socket_, &read_fds);
|
||
|
else
|
||
|
FD_SET(socket_, &write_fds);
|
||
|
for (size_t i = 0; i < events_.size(); ++i)
|
||
|
FD_SET(events_[i].fd, &read_fds);
|
||
|
timeval tv = {};
|
||
|
timeval* tv_ptr = NULL;
|
||
|
if (timeout_secs > 0) {
|
||
|
tv.tv_sec = timeout_secs;
|
||
|
tv.tv_usec = 0;
|
||
|
tv_ptr = &tv;
|
||
|
}
|
||
|
int max_fd = socket_;
|
||
|
for (size_t i = 0; i < events_.size(); ++i)
|
||
|
if (events_[i].fd > max_fd)
|
||
|
max_fd = events_[i].fd;
|
||
|
if (HANDLE_EINTR(
|
||
|
select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
|
||
|
return false;
|
||
|
}
|
||
|
bool event_was_fired = false;
|
||
|
for (size_t i = 0; i < events_.size(); ++i) {
|
||
|
if (FD_ISSET(events_[i].fd, &read_fds)) {
|
||
|
events_[i].was_fired = true;
|
||
|
event_was_fired = true;
|
||
|
}
|
||
|
}
|
||
|
return !event_was_fired;
|
||
|
}
|
||
|
|
||
|
// static
|
||
|
int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) {
|
||
|
return std::max(s1.socket_, s2.socket_);
|
||
|
}
|
||
|
|
||
|
// static
|
||
|
pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
|
||
|
Socket socket;
|
||
|
if (!socket.ConnectUnix(path))
|
||
|
return -1;
|
||
|
ucred ucred;
|
||
|
socklen_t len = sizeof(ucred);
|
||
|
if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
|
||
|
CHECK_NE(ENOPROTOOPT, errno);
|
||
|
return -1;
|
||
|
}
|
||
|
return ucred.pid;
|
||
|
}
|
||
|
|
||
|
} // namespace forwarder2
|