diff --git a/rdo.c b/rdo.c index 8cde84c..e91f4c0 100644 --- a/rdo.c +++ b/rdo.c @@ -7,6 +7,7 @@ #include #include #include +#include "sessions.h" void getconf(FILE* fp, const char* entry, char* result, size_t len_result) { char* line = NULL; @@ -45,8 +46,8 @@ int runprog(int argc, char** argv) { } int main(int argc, char** argv) { - char username[64], wrong_pw_sleep[64], password[128]; - unsigned int sleep_ms, tries; + char username[64], wrong_pw_sleep[64], session_ttl[64], password[128]; + unsigned int sleep_ms, tries, ts_ttl; if (argc < 2) errx(1, "Please specify a program to run"); @@ -62,10 +63,15 @@ int main(int argc, char** argv) { getconf(fp, "username", username, sizeof(username)); getconf(fp, "wrong_pw_sleep", wrong_pw_sleep, sizeof(wrong_pw_sleep)); + getconf(fp, "session_ttl", session_ttl, sizeof(session_ttl)); sleep_ms = atoi(wrong_pw_sleep) * 1000; + ts_ttl = atoi(session_ttl) * 60 * 100; fclose(fp); + if (getsession(getppid(), ts_ttl) == 0) + return runprog(argc, argv); + struct passwd* p = getpwnam(username); if (!p) err(1, "Could not get user info"); @@ -84,8 +90,10 @@ int main(int argc, char** argv) { if (!readpassphrase("(rdo) Password: ", password, sizeof(password), RPP_REQUIRE_TTY)) err(1, "Could not get passphrase"); - if (strcmp(shadowEntry->sp_pwdp, crypt(password, shadowEntry->sp_pwdp)) == 0) + if (strcmp(shadowEntry->sp_pwdp, crypt(password, shadowEntry->sp_pwdp)) == 0) { + setsession(getppid(), ts_ttl); return runprog(argc, argv); + } usleep(sleep_ms); fprintf(stderr, "Wrong password.\n"); diff --git a/rdo_sample.conf b/rdo_sample.conf index 3c00c78..94842f1 100644 --- a/rdo_sample.conf +++ b/rdo_sample.conf @@ -1,2 +1,3 @@ username=sw1tchbl4d3 wrong_pw_sleep=1000 +session_ttl=5 diff --git a/sessions.h b/sessions.h new file mode 100644 index 0000000..e4f798a --- /dev/null +++ b/sessions.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +int getpstartts(int pid, unsigned long long* startts) { + char path[255], fc[1024]; + char* ptr = fc; + + snprintf(path, sizeof(path), "/proc/%d/stat", pid); + int fd = open(path, O_RDONLY); + + if (fd < 0) + err(1, "Could not open %s", path); + + int bytes_read = read(fd, fc, sizeof(fc)); + close(fd); + + if (memchr(ptr, '\0', bytes_read) != NULL) + return -1; + + ptr = strrchr(fc, ')'); + + char* token = strtok(ptr, " "); + + for (short i = 0; i<20 && token; i++) + token = strtok(NULL, " "); + + if (!token) + return -1; + + unsigned long long temp_ts = strtoull(token, NULL, 10); + if (temp_ts == 0 || temp_ts == ULLONG_MAX) + return -1; + + *startts = temp_ts; + return 0; +} + +int gethandle(int recur) { + if (recur >= 2) + errx(1, "Too many recursions in gethandle()"); + + struct stat st; + int fd = open("/run/rdo", O_RDONLY, O_DIRECTORY | O_NOFOLLOW); + + if (fd < 0) { + if (errno == ENOENT) { + if (mkdir("/run/rdo", 0700) < 0) + err(1, "Could not create /run/rdo"); + return gethandle(++recur); + } + else + err(1, "Could not open /run/rdo"); + } else { + if (fstat(fd, &st) < 0) + err(1, "Could not fstat /run/rdo"); + + if (st.st_uid != 0 || st.st_mode != (0700 | S_IFDIR)) + return -1; + } + + return fd; +} + +void setsession(int pid, unsigned int ts_ttl) { + if (ts_ttl == 0) + return; + + unsigned long long startts; + char path[1024], ts_str[32]; + + int dirfd = gethandle(0); + if (dirfd < 0 && errno == 0) + return; + + if (getpstartts(pid, &startts) < 0) + return; + + snprintf(path, sizeof(path), "/run/rdo/%d-%llu", pid, startts); + + int fd = openat(dirfd, path, O_CREAT | O_EXCL | O_WRONLY, 0700); + if (fd < 0) { + if (errno == EEXIST) + return; + err(1, "Could not open %s", path); + } + + snprintf(ts_str, sizeof(ts_str), "%llu", (unsigned long long)time(NULL)); + + if (write(fd, ts_str, sizeof(ts_str)) < 0) + err(1, "Could not write to %s", path); + + close(fd); + + return; +} + +int getsession(int pid, unsigned int ts_ttl) { + if (ts_ttl == 0) + return -1; + + unsigned long long startts, current; + char path[1024], ts_str[32]; + + int dirfd = gethandle(0); + if (dirfd < 0 && errno == 0) + return -1; + + if (getpstartts(pid, &startts) < 0) + return -1; + + snprintf(path, sizeof(path), "/run/rdo/%d-%llu", pid, startts); + + int fd = openat(dirfd, path, O_RDONLY); + if (fd < 0) { + if (errno == ENOENT) + return -1; + err(1, "Could not open %s", path); + } + + printf("%d", fd); + if (read(fd, ts_str, sizeof(ts_str)) < 0) + err(1, "Could not read %s", path); + + startts = strtoull(ts_str, NULL, 10); + current = time(NULL); + + if (current - startts > ts_ttl) { + unlink(path); + return -1; + } + + return 0; +}