#define MODULE_NAME "notify"
#define RECEIVE_BUFFER_SIZE 2048
+#if defined(__linux__)
+#define PLATFORM_LINUX 1
+#else
+#include <sys/ucred.h>
+#define PLATFORM_LINUX 0
+#endif
+
+int create_nonblocking_socket(int domain, int type, int protocol);
+
static PyObject *NotifySocketError;
+int socket_nonblocking(int domain, int type, int protocol) {
+#if PLATFORM_LINUX
+ return socket(domain, type | SOCK_NONBLOCK | SOCK_CLOEXEC, protocol);
+#else
+ int sockfd = socket(domain, type, protocol);
+
+ /* set the socket to nonblocking mode */
+ int flags = fcntl(sockfd, F_GETFL, 0);
+ if (flags == -1)
+ return -1;
+ if (fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) < 0)
+ return -1;
+
+ int fdflags = fcntl(sockfd, F_GETFD);
+ if (fdflags != -1) {
+ if (fcntl(sockfd, F_SETFD, fdflags | FD_CLOEXEC) < 0)
+ return -1;
+ }
+ return sockfd;
+#endif
+}
+
static PyObject *init_control_socket(PyObject *self, PyObject *args)
{
/* create socket */
- int controlfd = socket(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0);
+ int controlfd = socket_nonblocking(AF_UNIX, SOCK_DGRAM, 0);
if (controlfd == -1) goto fail_errno;
/* construct the address; sd_notify() requires that the path is absolute */
int res = bind(controlfd, (struct sockaddr *)&server_addr, sizeof(server_addr));
if (res < 0) goto fail_errno;
+#if PLATFORM_LINUX
/* make sure that we get credentials with messages */
int data = (int)true;
res = setsockopt(controlfd, SOL_SOCKET, SO_PASSCRED, &data, sizeof(data));
if (res < 0) goto fail_errno;
+#endif
/* store the name of the socket in env to fake systemd */
char *old_value = getenv(NOTIFY_SOCKET_NAME);
if (old_value != NULL) {
printf("[notify_socket] warning, running under systemd and overwriting $%s\n",
- NOTIFY_SOCKET_NAME);
+ NOTIFY_SOCKET_NAME);
// fixme
}
char place_for_data[RECEIVE_BUFFER_SIZE];
bzero(&place_for_data, sizeof(place_for_data));
struct iovec iov = { .iov_base = &place_for_data,
- .iov_len = sizeof(place_for_data) };
+ .iov_len = sizeof(place_for_data) };
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
+#if PLATFORM_LINUX
char cmsg[CMSG_SPACE(sizeof(struct ucred))];
+#else
+ char cmsg[0];
+#endif
msg.msg_control = cmsg;
msg.msg_controllen = sizeof(cmsg);
+ pid_t pid = -1;
+#if PLATFORM_LINUX
/* Receive real plus ancillary data */
int len = recvmsg(controlfd, &msg, 0);
if (len == -1) {
/* read the sender pid */
struct cmsghdr *cmsgp = CMSG_FIRSTHDR(&msg);
- pid_t pid = -1;
while (cmsgp != NULL) {
if (cmsgp->cmsg_type == SCM_CREDENTIALS) {
- if (
- cmsgp->cmsg_len != CMSG_LEN(sizeof(struct ucred)) ||
- cmsgp->cmsg_level != SOL_SOCKET
- ) {
- printf("[notify_socket] invalid cmsg data, ignoring\n");
- Py_RETURN_NONE;
- }
+ if (
+ cmsgp->cmsg_len != CMSG_LEN(sizeof(struct ucred)) ||
+ cmsgp->cmsg_level != SOL_SOCKET
+ ) {
+ printf("[notify_socket] invalid cmsg data, ignoring\n");
+ Py_RETURN_NONE;
+ }
struct ucred cred;
memcpy(&cred, CMSG_DATA(cmsgp), sizeof(cred));
}
cmsgp = CMSG_NXTHDR(&msg, cmsgp);
}
+#else
+ struct xucred cred;
+ socklen_t len = sizeof(cred);
+ getsockopt(controlfd, 0, LOCAL_PEERCRED, &cred, &len);
+#endif
+
if (pid == -1) {
printf("[notify_socket] ignoring received data without credentials: %s\n",
- place_for_data);
+ place_for_data);
Py_RETURN_NONE;
}
PyModuleDef_HEAD_INIT, MODULE_NAME, /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
- or -1 if the module keeps state in global variables. */
+ or -1 if the module keeps state in global variables. */
NotifyMethods
};