/* Copyright (c) 2017 Arun Muralidharan Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #ifndef JWT_IPP #define JWT_IPP #include "jwt/config.hpp" #include "jwt/detail/meta.hpp" #include #include namespace jwt { /** */ static inline void jwt_throw_exception(const std::error_code& ec); template std::string to_json_str(const T& obj, bool pretty) { return pretty ? obj.create_json_obj().dump(2) : obj.create_json_obj().dump() ; } template std::ostream& write(std::ostream& os, const T& obj, bool pretty) { pretty ? (os << std::setw(2) << obj.create_json_obj()) : (os << obj.create_json_obj()) ; return os; } template std::ostream& operator<< (std::ostream& os, const T& obj) { os << obj.create_json_obj(); return os; } //======================================================================== inline void jwt_header::decode(const jwt::string_view enc_str, std::error_code& ec) { ec.clear(); std::string json_str = base64_decode(enc_str); try { payload_ = json_t::parse(std::move(json_str)); } catch(const std::exception&) { ec = DecodeErrc::JsonParseError; return; } //Look for the algorithm field auto alg_itr = payload_.find("alg"); if (alg_itr == payload_.end()) { ec = DecodeErrc::AlgHeaderMiss; return; } alg_ = str_to_alg(alg_itr.value().get()); if (alg_ != algorithm::NONE) { auto itr = payload_.find("typ"); if (itr != payload_.end()) { const auto& typ = itr.value().get(); if (strcasecmp(typ.c_str(), "JWT")) { ec = DecodeErrc::TypMismatch; return; } typ_ = str_to_type(typ); } } else { //TODO: } // Populate header for (auto it = payload_.begin(); it != payload_.end(); ++it) { auto ret = headers_.insert(it.key()); if (!ret.second) { ec = DecodeErrc::DuplClaims; //ATTN: Dont stop the decode here //Not a hard error. } } return; } inline void jwt_header::decode(const jwt::string_view enc_str) { std::error_code ec; decode(enc_str, ec); if (ec) { throw DecodeError(ec.message()); } return; } inline void jwt_payload::decode(const jwt::string_view enc_str, std::error_code& ec) { ec.clear(); std::string json_str = base64_decode(enc_str); try { payload_ = json_t::parse(std::move(json_str)); } catch(const std::exception&) { ec = DecodeErrc::JsonParseError; return; } //populate the claims set for (auto it = payload_.begin(); it != payload_.end(); ++it) { auto ret = claim_names_.insert(it.key()); if (!ret.second) { ec = DecodeErrc::DuplClaims; break; } } return; } inline void jwt_payload::decode(const jwt::string_view enc_str) { std::error_code ec; decode(enc_str, ec); if (ec) { throw DecodeError(ec.message()); } return; } inline std::string jwt_signature::encode(const jwt_header& header, const jwt_payload& payload, std::error_code& ec) { std::string jwt_msg; ec.clear(); //TODO: Optimize allocations sign_func_t sign_fn = get_sign_algorithm_impl(header); std::string hdr_sign = header.base64_encode(); std::string pld_sign = payload.base64_encode(); std::string data = hdr_sign + '.' + pld_sign; auto res = sign_fn(key_, data); if (res.second && res.second != AlgorithmErrc::NoneAlgorithmUsed) { ec = res.second; return {}; } std::string b64hash; if (!res.second) { b64hash = base64_encode(res.first.c_str(), res.first.length()); } auto new_len = base64_uri_encode(&b64hash[0], b64hash.length()); b64hash.resize(new_len); jwt_msg = data + '.' + b64hash; return jwt_msg; } inline verify_result_t jwt_signature::verify(const jwt_header& header, const jwt::string_view hdr_pld_sign, const jwt::string_view jwt_sign) { verify_func_t verify_fn = get_verify_algorithm_impl(header); return verify_fn(key_, hdr_pld_sign, jwt_sign); } inline sign_func_t jwt_signature::get_sign_algorithm_impl(const jwt_header& hdr) const noexcept { sign_func_t ret = nullptr; switch (hdr.algo()) { case algorithm::HS256: ret = HMACSign::sign; break; case algorithm::HS384: ret = HMACSign::sign; break; case algorithm::HS512: ret = HMACSign::sign; break; case algorithm::NONE: ret = HMACSign::sign; break; case algorithm::RS256: ret = PEMSign::sign; break; case algorithm::RS384: ret = PEMSign::sign; break; case algorithm::RS512: ret = PEMSign::sign; break; case algorithm::ES256: ret = PEMSign::sign; break; case algorithm::ES384: ret = PEMSign::sign; break; case algorithm::ES512: ret = PEMSign::sign; break; default: assert (0 && "Code not reached"); }; return ret; } inline verify_func_t jwt_signature::get_verify_algorithm_impl(const jwt_header& hdr) const noexcept { verify_func_t ret = nullptr; switch (hdr.algo()) { case algorithm::HS256: ret = HMACSign::verify; break; case algorithm::HS384: ret = HMACSign::verify; break; case algorithm::HS512: ret = HMACSign::verify; break; case algorithm::NONE: ret = HMACSign::verify; break; case algorithm::RS256: ret = PEMSign::verify; break; case algorithm::RS384: ret = PEMSign::verify; break; case algorithm::RS512: ret = PEMSign::verify; break; case algorithm::ES256: ret = PEMSign::verify; break; case algorithm::ES384: ret = PEMSign::verify; break; case algorithm::ES512: ret = PEMSign::verify; break; default: assert (0 && "Code not reached"); }; return ret; } // template jwt_object::jwt_object( First&& first, Rest&&... rest) { static_assert (detail::meta::is_parameter_concept::value && detail::meta::are_all_params::value, "All constructor argument types must model ParameterConcept"); set_parameters(std::forward(first), std::forward(rest)...); } template void jwt_object::set_parameters( params::detail::payload_param&& payload, Rest&&... rargs) { for (const auto& elem : payload.get()) { payload_.add_claim(std::move(elem.first), std::move(elem.second)); } set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::secret_param secret, Rest&&... rargs) { secret_.assign(secret.get().data(), secret.get().length()); set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::algorithm_param alg, Rest&&... rargs) { header_.algo(alg.get()); set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::headers_param&& header, Rest&&... rargs) { for (const auto& elem : header.get()) { header_.add_header(std::move(elem.first), std::move(elem.second)); } set_parameters(std::forward(rargs)...); } inline void jwt_object::set_parameters() { //sentinel call return; } inline jwt_object& jwt_object::add_claim(const jwt::string_view name, system_time_t tp) { return add_claim( name, std::chrono::duration_cast< std::chrono::seconds>(tp.time_since_epoch()).count() ); } inline jwt_object& jwt_object::remove_claim(const jwt::string_view name) { payload_.remove_claim(name); return *this; } inline std::string jwt_object::signature(std::error_code& ec) const { ec.clear(); //key/secret should be set for any algorithm except NONE if (header().algo() != jwt::algorithm::NONE) { if (secret_.length() == 0) { ec = AlgorithmErrc::KeyNotFoundErr; return {}; } } jwt_signature jws{secret_}; return jws.encode(header_, payload_, ec); } inline std::string jwt_object::signature() const { std::error_code ec; std::string res = signature(ec); if (ec) { throw SigningError(ec.message()); } return res; } template std::error_code jwt_object::verify( const Params& dparams, const params::detail::algorithms_param& algos) const { std::error_code ec{}; //Verify if the algorithm set in the header //is any of the one expected by the client. auto fitr = std::find_if(algos.get().begin(), algos.get().end(), [this](const auto& elem) { return jwt::str_to_alg(elem) == this->header().algo(); }); if (fitr == algos.get().end()) { ec = VerificationErrc::InvalidAlgorithm; return ec; } //Check for the expiry timings if (has_claim(registered_claims::expiration)) { auto curr_time = std::chrono::duration_cast< std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); auto p_exp = payload() .get_claim_value(registered_claims::expiration); if (static_cast(curr_time) > static_cast(p_exp + dparams.leeway)) { ec = VerificationErrc::TokenExpired; return ec; } } //Check for issuer if (dparams.has_issuer) { if (has_claim(registered_claims::issuer)) { const std::string& p_issuer = payload() .get_claim_value(registered_claims::issuer); if (p_issuer != dparams.issuer) { ec = VerificationErrc::InvalidIssuer; return ec; } } else { ec = VerificationErrc::InvalidIssuer; return ec; } } //Check for audience if (dparams.has_aud) { if (has_claim(registered_claims::audience)) { const std::string& p_aud = payload() .get_claim_value(registered_claims::audience); if (p_aud != dparams.aud) { ec = VerificationErrc::InvalidAudience; return ec; } } else { ec = VerificationErrc::InvalidAudience; return ec; } } //Check the subject if (dparams.has_sub) { if (has_claim(registered_claims::subject)) { const std::string& p_sub = payload() .get_claim_value(registered_claims::subject); if (p_sub != dparams.sub) { ec = VerificationErrc::InvalidSubject; return ec; } } else { ec = VerificationErrc::InvalidSubject; return ec; } } //Check for NBF if (has_claim(registered_claims::not_before)) { auto curr_time = std::chrono::duration_cast< std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); auto p_exp = payload() .get_claim_value(registered_claims::not_before); if (static_cast(p_exp - dparams.leeway) > static_cast(curr_time)) { ec = VerificationErrc::ImmatureSignature; return ec; } } //Check IAT validation if (dparams.validate_iat) { if (!has_claim(registered_claims::issued_at)) { ec = VerificationErrc::InvalidIAT; return ec; } else { // Will throw type conversion error auto val = payload() .get_claim_value(registered_claims::issued_at); (void)val; } } //Check JTI validation if (dparams.validate_jti) { if (!has_claim("jti")) { ec = VerificationErrc::InvalidJTI; return ec; } } return ec; } inline std::array jwt_object::three_parts(const jwt::string_view enc_str) { std::array result; size_t fpos = enc_str.find_first_of('.'); assert (fpos != jwt::string_view::npos); result[0] = jwt::string_view{&enc_str[0], fpos}; size_t spos = enc_str.find_first_of('.', fpos + 1); result[1] = jwt::string_view{&enc_str[fpos + 1], spos - fpos - 1}; if (spos + 1 != enc_str.length()) { result[2] = jwt::string_view{&enc_str[spos + 1], enc_str.length() - spos - 1}; } return result; } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::secret_param s, Rest&&... args) { dparams.secret.assign(s.get().data(), s.get().length()); dparams.has_secret = true; jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::secret_function_param&& s, Rest&&... args) { dparams.secret = s.get(*dparams.payload_ptr); dparams.has_secret = true; jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::leeway_param l, Rest&&... args) { dparams.leeway = l.get(); jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::verify_param v, Rest&&... args) { dparams.verify = v.get(); jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::issuer_param i, Rest&&... args) { dparams.issuer = std::move(i).get(); dparams.has_issuer = true; jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::audience_param a, Rest&&... args) { dparams.aud = std::move(a).get(); dparams.has_aud = true; jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::subject_param s, Rest&&... args) { dparams.sub = std::move(s).get(); dparams.has_sub = true; jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::validate_iat_param v, Rest&&... args) { dparams.validate_iat = v.get(); jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::validate_jti_param v, Rest&&... args) { dparams.validate_jti = v.get(); jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams) { (void) dparams; // prevent -Wunused-parameter with gcc return; } //================================================================== template jwt_object decode(const jwt::string_view enc_str, const params::detail::algorithms_param& algos, std::error_code& ec, Args&&... args) { ec.clear(); jwt_object obj; if (algos.get().size() == 0) { ec = DecodeErrc::EmptyAlgoList; return obj; } struct decode_params { /// key to decode the JWS bool has_secret = false; std::string secret; /// Verify parameter. Defaulted to true. bool verify = true; /// Leeway parameter. Defaulted to zero seconds. uint32_t leeway = 0; ///The issuer //TODO: optional type bool has_issuer = false; std::string issuer; ///The audience //TODO: optional type bool has_aud = false; std::string aud; //The subject //TODO: optional type bool has_sub = false; std::string sub; //Validate IAT bool validate_iat = false; //Validate JTI bool validate_jti = false; const jwt_payload* payload_ptr = 0; }; decode_params dparams{}; //Signature must have atleast 2 dots auto dot_cnt = std::count_if(std::begin(enc_str), std::end(enc_str), [](char ch) { return ch == '.'; }); if (dot_cnt < 2) { ec = DecodeErrc::SignatureFormatError; return obj; } auto parts = jwt_object::three_parts(enc_str); //throws decode error jwt_header hdr{}; hdr.decode(parts[0], ec); if (ec) { return obj; } //obj.header(jwt_header{parts[0]}); obj.header(std::move(hdr)); //If the algorithm is not NONE, it must not //have more than two dots ('.') and the split //must result in three strings with some length. if (obj.header().algo() != jwt::algorithm::NONE) { if (dot_cnt > 2) { ec = DecodeErrc::SignatureFormatError; return obj; } if (parts[2].length() == 0) { ec = DecodeErrc::SignatureFormatError; return obj; } } //throws decode error jwt_payload payload{}; payload.decode(parts[1], ec); if (ec) { return obj; } obj.payload(std::move(payload)); dparams.payload_ptr = & obj.payload(); jwt_object::set_decode_params(dparams, std::forward(args)...); if (dparams.verify) { try { ec = obj.verify(dparams, algos); } catch (const json_ns::detail::type_error&) { ec = VerificationErrc::TypeConversionError; } if (ec) return obj; //Verify the signature only if some algorithm was used if (obj.header().algo() != algorithm::NONE) { if (!dparams.has_secret) { ec = DecodeErrc::KeyNotPresent; return obj; } jwt_signature jsign{dparams.secret}; // Length of the encoded header and payload only. // Addition of '1' to account for the '.' character. auto l = parts[0].length() + 1 + parts[1].length(); //MemoryAllocationError is not caught verify_result_t res = jsign.verify(obj.header(), enc_str.substr(0, l), parts[2]); if (res.second) { ec = res.second; return obj; } if (!res.first) { ec = VerificationErrc::InvalidSignature; return obj; } } else { ec = AlgorithmErrc::NoneAlgorithmUsed; } } return obj; } template jwt_object decode(const jwt::string_view enc_str, const params::detail::algorithms_param& algos, Args&&... args) { std::error_code ec{}; auto jwt_obj = decode(enc_str, algos, ec, std::forward(args)...); if (ec) { jwt_throw_exception(ec); } return jwt_obj; } void jwt_throw_exception(const std::error_code& ec) { const auto& cat = ec.category(); if (&cat == &theVerificationErrorCategory || std::string(cat.name()) == std::string(theVerificationErrorCategory.name())) { switch (static_cast(ec.value())) { case VerificationErrc::InvalidAlgorithm: { throw InvalidAlgorithmError(ec.message()); } case VerificationErrc::TokenExpired: { throw TokenExpiredError(ec.message()); } case VerificationErrc::InvalidIssuer: { throw InvalidIssuerError(ec.message()); } case VerificationErrc::InvalidAudience: { throw InvalidAudienceError(ec.message()); } case VerificationErrc::InvalidSubject: { throw InvalidSubjectError(ec.message()); } case VerificationErrc::InvalidIAT: { throw InvalidIATError(ec.message()); } case VerificationErrc::InvalidJTI: { throw InvalidJTIError(ec.message()); } case VerificationErrc::ImmatureSignature: { throw ImmatureSignatureError(ec.message()); } case VerificationErrc::InvalidSignature: { throw InvalidSignatureError(ec.message()); } case VerificationErrc::TypeConversionError: { throw TypeConversionError(ec.message()); } default: assert (0 && "Unknown error code"); }; } if (&cat == &theDecodeErrorCategory || std::string(cat.name()) == std::string(theDecodeErrorCategory.name())) { switch (static_cast(ec.value())) { case DecodeErrc::SignatureFormatError: { throw SignatureFormatError(ec.message()); } case DecodeErrc::KeyNotPresent: { throw KeyNotPresentError(ec.message()); } case DecodeErrc::KeyNotRequiredForNoneAlg: { // Not an error. Just to be ignored. break; } default: { throw DecodeError(ec.message()); } }; assert (0 && "Unknown error code"); } if (&cat == &theAlgorithmErrCategory || std::string(cat.name()) == std::string(theAlgorithmErrCategory.name())) { switch (static_cast(ec.value())) { case AlgorithmErrc::InvalidKeyErr: { throw InvalidKeyError(ec.message()); } case AlgorithmErrc::VerificationErr: { throw InvalidSignatureError(ec.message()); } case AlgorithmErrc::NoneAlgorithmUsed: { //Not an error actually. break; } default: assert (0 && "Unknown error code or not to be treated as an error"); }; } return; } } // END namespace jwt #endif