diff --git a/rdo.c b/rdo.c index eee4aef..8c74445 100644 --- a/rdo.c +++ b/rdo.c @@ -53,7 +53,11 @@ int main(int argc, char** argv) { char groupname[64], wrong_pw_sleep[64], session_ttl[64], password[128]; unsigned int sleep_us, tries, ts_ttl; - if (argc == 1) { + int read_pw_from_stdin = 0; + if (argc > 1) + read_pw_from_stdin = strcmp(argv[1], "-") == 0; + + if (argc == 1 || (read_pw_from_stdin && argc == 2)) { printf("RootDO version: %s\n\n", VERSION); printf("Usage: %s [command]\n", argv[0]); return 0; @@ -64,7 +68,7 @@ int main(int argc, char** argv) { int ruid = getuid(); if (ruid == 0) - runprog(&argv[1]); + runprog(&argv[read_pw_from_stdin+1]); FILE* fp = fopen("/etc/rdo.conf", "r"); @@ -79,7 +83,7 @@ int main(int argc, char** argv) { fclose(fp); - if (getsession(getppid(), ts_ttl, ruid) == 0) + if (getsession(getppid(), ts_ttl, ruid) == 0 && !read_pw_from_stdin) runprog(&argv[1]); struct passwd* pw = getpwuid(ruid); @@ -117,7 +121,7 @@ int main(int argc, char** argv) { tries = 0; while (tries < 3) { - if (!readpassphrase("(rdo) Password: ", password, sizeof(password))) + if (!readpassphrase("(rdo) Password: ", password, sizeof(password), read_pw_from_stdin)) err(1, "Could not get passphrase"); char* hashed_pw = crypt(password, shadowEntry->sp_pwdp); @@ -127,8 +131,9 @@ int main(int argc, char** argv) { errx(1, "Could not hash password, does your user have a password?"); if (strcmp(shadowEntry->sp_pwdp, hashed_pw) == 0) { - setsession(getppid(), ts_ttl, ruid); - runprog(&argv[1]); + if (!read_pw_from_stdin) + setsession(getppid(), ts_ttl, ruid); + runprog(&argv[read_pw_from_stdin+1]); } usleep(sleep_us); diff --git a/readpassphrase.h b/readpassphrase.h index 7a7011b..d4f8be7 100644 --- a/readpassphrase.h +++ b/readpassphrase.h @@ -3,44 +3,58 @@ #include #include -char* readpassphrase(const char* prompt, char* buf, size_t bufsz) { +char* readpassphrase(const char* prompt, char* buf, size_t bufsz, int read_pw_from_stdin) { int n; - int ttyfd = -1; + int is_tty = !read_pw_from_stdin; + int infd = -1; + int outfd = -1; struct termios term; - for (int i = 0; i < 3; i++) { - if (tcgetattr(i, &term) == 0) { - ttyfd = i; - break; + if (read_pw_from_stdin) { + infd = 0; + outfd = 1; + is_tty = tcgetattr(outfd, &term) == 0; + } else { + for (int i = 0; i < 3; i++) { + if (tcgetattr(i, &term) == 0) { + infd = i; + outfd = i; + break; + } } } - - if (ttyfd < 0) + + if (infd < 0) return NULL; - term.c_lflag &= ~ECHO; - tcsetattr(ttyfd, 0, &term); - term.c_lflag |= ECHO; + if (is_tty) { + term.c_lflag &= ~ECHO; + tcsetattr(outfd, 0, &term); + term.c_lflag |= ECHO; + } - if (write(ttyfd, prompt, strlen(prompt)) < 0) { - tcsetattr(ttyfd, 0, &term); + if (write(outfd, prompt, strlen(prompt)) < 0) { + if (is_tty) + tcsetattr(outfd, 0, &term); return NULL; } - n = read(ttyfd, buf, bufsz); + n = read(infd, buf, bufsz); if (n < 0) { - tcsetattr(ttyfd, 0, &term); - n = write(ttyfd, "\n", 1); + if (is_tty) + tcsetattr(outfd, 0, &term); + n = write(outfd, "\n", 1); return NULL; } buf[n-1] = '\0'; // NOTE: As we disabled echo, the enter sent by the user isn't displayed, so we resend it. - n = write(ttyfd, "\n", 1); + n = write(outfd, "\n", 1); - tcsetattr(ttyfd, 0, &term); + if (is_tty) + tcsetattr(outfd, 0, &term); return buf; }