diff --git a/canvassyncer/__main__.py b/canvassyncer/__main__.py index 162692c..cd897ce 100644 --- a/canvassyncer/__main__.py +++ b/canvassyncer/__main__.py @@ -328,64 +328,56 @@ def initConfig(): oldConfig = json.load(open(CONFIG_PATH)) elif os.path.exists("./canvassyncer.json"): oldConfig = json.load(open("./canvassyncer.json")) + + def promptConfigStr(promptStr, key, *, defaultValOnMissing=None): + defaultVal = oldConfig.get(key) + if defaultVal is None: + if defaultValOnMissing is not None: + defaultVal = defaultValOnMissing + else: + defaultVal = "" + elif isinstance(defaultVal, list): + defaultVal = " ".join((str(val) for val in defaultVal)) + defaultVal = str(defaultVal) + tipStr = f"(Default: {defaultVal})" if defaultVal else "" + res = input(f"{promptStr}{tipStr}: ").strip() + if not res: + res = defaultVal + return res + print("Generating new config file...") - prevu = oldConfig.get("canvasURL", "") if oldConfig else "https://umjicanvas.com" - url = input("Canvas url(Defuault: " + prevu + "):").strip() - if not url: - url = prevu - tipStr = f"(Default: {oldConfig.get('token', '')})" if oldConfig else "" - token = input(f"Canvas access token{tipStr}:").strip() - if not token: - token = oldConfig.get("token", "") - tipStr = ( - f"(Default: {' '.join(oldConfig.get('courseCodes', []))})" if oldConfig else "" + url = promptConfigStr( + "Canvas url", "canvasURL", defaultValOnMissing="https://umjicanvas.com" ) - courseCodes = ( - input(f"Courses to sync in course codes(split with space){tipStr}:") - .strip() - .split() + token = promptConfigStr("Canvas access token", "token") + courseCodesStr = promptConfigStr( + "Courses to sync in course codes(split with space)", "courseCodes" ) - if not courseCodes: - courseCodes = oldConfig.get("courseCodes", []) - tipStr = ( - f"(Default: {' '.join([str(item) for item in oldConfig.get('courseIDs', [])])})" - if oldConfig - else "" + courseCodes = courseCodesStr.split() + courseIDsStr = promptConfigStr( + "Courses to sync in course ID(split with space)", "courseIDs" ) - courseIDs = ( - input(f"Courses to sync in course ID(split with space){tipStr}:") - .strip() - .split() + courseIDs = [int(courseID) for courseID in courseIDsStr.split()] + downloadDir = promptConfigStr( + "Path to save canvas files", + "downloadDir", + defaultValOnMissing=os.path.abspath(""), ) - if not courseIDs: - courseIDs = oldConfig.get("courseIDs", []) - courseIDs = [int(courseID) for courseID in courseIDs] - tipStr = f"(Default: {oldConfig.get('downloadDir', os.path.abspath(''))})" - downloadDir = input(f"Path to save canvas files{tipStr}:").strip() - if not downloadDir: - downloadDir = oldConfig.get("downloadDir", os.path.abspath("")) - tipStr = ( - f"(Default: {oldConfig.get('filesizeThresh', '')})" - if oldConfig - else f"(Default: 250)" + filesizeThreshStr = promptConfigStr( + "Maximum file size to download(MB)", "filesizeThresh", defaultValOnMissing=250 ) - filesizeThresh = input(f"Maximum file size to download(MB){tipStr}:").strip() try: - filesizeThresh = float(filesizeThresh) + filesizeThresh = float(filesizeThreshStr) except Exception: filesizeThresh = 250 - json.dump( - { - "canvasURL": url, - "token": token, - "courseCodes": courseCodes, - "courseIDs": courseIDs, - "downloadDir": downloadDir, - "filesizeThresh": filesizeThresh, - }, - open(CONFIG_PATH, mode="w", encoding="utf-8"), - indent=4, - ) + return { + "canvasURL": url, + "token": token, + "courseCodes": courseCodes, + "courseIDs": courseIDs, + "downloadDir": downloadDir, + "filesizeThresh": filesizeThresh, + } def getConfig(): @@ -423,15 +415,17 @@ def getConfig(): if not os.path.exists(configPath): print("Config file not exist, creating...") try: - initConfig() + json.dump( + initConfig(), + open(configPath, mode="w", encoding="utf-8"), + indent=4, + ) except Exception as e: print(f"\nError: {e.__class__.__name__}. Failed to create config file.") if args.debug: print(traceback.format_exc()) exit(1) - if args.r: - return - config = json.load(open(configPath, "r", encoding="UTF-8")) + config = json.load(open(configPath, mode="r", encoding="utf-8")) config["y"] = args.y config["proxies"] = args.proxy config["no_subfolder"] = args.no_subfolder @@ -463,10 +457,9 @@ async def sync(): except Exception as e: errorName = e.__class__.__name__ print( - f"Unexpected error: {errorName}. Please check your network and token!", - end="", + f"Unexpected error: {errorName}. Please check your network and token!" + + ("" if config["debug"] else " Or use -d for detailed information.") ) - print("" if config["debug"] else " Or use -d for detailed information.") if config["debug"]: print(traceback.format_exc()) finally: