// Referenced images: fig37.png fig39.png romania.png

#include <iostream>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <cassert>
using namespace std;
 
// Cities of the Romania road map.
// The numeric values are written out because State converts between City and
// Action by casting, so an enumerator here and its TO_* counterpart in Action
// must keep the same value.
enum class City
{
    ORADEA         = 0,
    ZERIND         = 1,
    ARAD           = 2,
    TIMISOARA      = 3,
    LUGOJ          = 4,
    MEHADIA        = 5,
    DOBRETA        = 6,
    SIBIU          = 7,
    RIMNICU_VILCEA = 8,
    CRAIOVA        = 9,
    FAGARAS        = 10,
    PITESTI        = 11,
    GIURGIU        = 12,
    BUCHAREST      = 13,
    NEAMT          = 14,
    URZICENI       = 15,
    IASI           = 16,
    VASLUI         = 17,
    HIRSOVA        = 18,
    EFORIE         = 19
};
// Maps a city to its printable name.
// Returns "???" for a value that is not one of the City enumerators.
string c2s(const City& c)
{
    switch(c)
    {
    case City::ORADEA:
        return "Oradea";
    case City::ZERIND:
        return "Zerind";
    case City::ARAD:
        return "Arad";
    case City::TIMISOARA:
        return "Timisoara";
    case City::LUGOJ:
        return "Lugoj";
    case City::MEHADIA:
        return "Mehadia";
    case City::DOBRETA:
        return "Dobreta";
    case City::SIBIU:
        return "Sibiu";
    case City::RIMNICU_VILCEA:
        return "RimnicuVilcea";
    case City::CRAIOVA:
        return "Craiova";
    case City::FAGARAS:
        return "Fagaras";
    case City::PITESTI:
        return "Pitesti";
    case City::GIURGIU:
        return "Giurgiu";
    case City::BUCHAREST:
        return "Bucharest";
    case City::NEAMT:
        return "Neamt";
    case City::URZICENI:
        return "Urziceni";
    case City::IASI:
        return "Iasi";
    case City::VASLUI:
        return "Vaslui";
    case City::HIRSOVA:
        return "Hirsova";
    case City::EFORIE:
        return "Eforie";
    }
    // Reached only when c holds a value outside the enumerator list.
    return "???";
}
 
// Moves available to the agent: "drive to city X".
// Every TO_X value equals the value of City::X; State::actions() and
// State::apply() convert between the two enums with a plain cast, so the
// numbering below must stay aligned with City.
enum class Action
{
    TO_ORADEA         = 0,
    TO_ZERIND         = 1,
    TO_ARAD           = 2,
    TO_TIMISOARA      = 3,
    TO_LUGOJ          = 4,
    TO_MEHADIA        = 5,
    TO_DOBRETA        = 6,
    TO_SIBIU          = 7,
    TO_RIMNICU_VILCEA = 8,
    TO_CRAIOVA        = 9,
    TO_FAGARAS        = 10,
    TO_PITESTI        = 11,
    TO_GIURGIU        = 12,
    TO_BUCHAREST      = 13,
    TO_NEAMT          = 14,
    TO_URZICENI       = 15,
    TO_IASI           = 16,
    TO_VASLUI         = 17,
    TO_HIRSOVA        = 18,
    TO_EFORIE         = 19
};
 
// Undirected road graph of Romania: road lengths between directly linked
// cities plus, for every city, the list of cities it is linked to.
class Romania_map
{
    // Road length for each ordered pair of directly linked cities.
    // Both (a, b) and (b, a) are stored, so lookups work in either direction.
    map<pair<City, City>, int> dist;
    // Directly linked cities of each city, in the order the links were added.
    map<City, vector<City> > neigh;

    // Registers a two-way road of length d between c1 and c2.
    void add_link(const City& c1, const City& c2, int d)
    {
        dist[{c1, c2}] = d;
        dist[{c2, c1}] = d;
        neigh[c1].push_back(c2);
        neigh[c2].push_back(c1);
    }

public:
    // Builds the complete map; each road is listed once, at the first of its
    // two endpoints in City order.
    Romania_map()
    {
        add_link(City::ORADEA, City::ZERIND, 71);
        add_link(City::ORADEA, City::SIBIU, 151);
        add_link(City::ZERIND, City::ARAD, 75);
        add_link(City::ARAD, City::TIMISOARA, 118);
        add_link(City::ARAD, City::SIBIU, 140);
        add_link(City::TIMISOARA, City::LUGOJ, 111);
        add_link(City::LUGOJ, City::MEHADIA, 70);
        add_link(City::MEHADIA, City::DOBRETA, 75);
        add_link(City::DOBRETA, City::CRAIOVA, 120);
        add_link(City::SIBIU, City::FAGARAS, 99);
        add_link(City::SIBIU, City::RIMNICU_VILCEA, 80);
        add_link(City::RIMNICU_VILCEA, City::PITESTI, 97);
        add_link(City::RIMNICU_VILCEA, City::CRAIOVA, 146);
        add_link(City::CRAIOVA, City::PITESTI, 138);
        add_link(City::FAGARAS, City::BUCHAREST, 211);
        add_link(City::PITESTI, City::BUCHAREST, 101);
        add_link(City::GIURGIU, City::BUCHAREST, 90);
        add_link(City::BUCHAREST, City::URZICENI, 85);
        add_link(City::NEAMT, City::IASI, 87);
        add_link(City::URZICENI, City::VASLUI, 142);
        add_link(City::URZICENI, City::HIRSOVA, 98);
        add_link(City::IASI, City::VASLUI, 92);
        // VASLUI: all of its roads are already registered above.
        add_link(City::HIRSOVA, City::EFORIE, 86);
        // EFORIE: all of its roads are already registered above.
    }

    // Returns the length of the direct road between c1 and c2,
    // or -1 when the two cities are not directly linked.
    int get_dist(const City& c1, const City& c2) const
    {
        // Single lookup instead of count() followed by at().
        const auto it = dist.find({c1, c2});
        return it != dist.end() ? it->second : -1;
    }

    // Returns the cities directly linked to c.
    // A city without any registered road yields an empty list (instead of
    // throwing), mirroring how get_dist reports "no road" with a sentinel.
    vector<City> get_neigh(const City& c) const
    {
        const auto it = neigh.find(c);
        if(it == neigh.end())
            return {};
        return it->second;
    }

    // Prints the number of stored (ordered) city pairs followed by one
    // "<from> <to> : <length>" line per pair.
    void list_dist() const
    {
        cout << "size: " << dist.size() << '\n';
        // Bind by const reference: no per-entry copy of the map element.
        for(const auto& [key, value] : dist)
        {
            cout << c2s(key.first) << " " << c2s(key.second) << " : " << value << '\n';
        }
    }

    // Prints the number of cities that have neighbours followed by one
    // "<city> : <neighbour> <neighbour> ..." line per city.
    void list_neigh() const
    {
        cout << "size: " << neigh.size() << '\n';
        // Bind by const reference: avoids copying each neighbour vector.
        for(const auto& [key, value] : neigh)
        {
            cout << c2s(key) << " : ";
            for(const City& c : value)
                cout << c2s(c) << " ";
            cout << '\n';
        }
    }

};
// The one shared map instance. State::actions() and Problem::action_cost()
// read it directly rather than receiving the map as a parameter.
Romania_map romania; //global object :(
 
class State
{
    City city;
public:
    State()
    {
        city = City::BUCHAREST;
    }
    State(City c)
    {
        city = c;
    }
    City get_city() const
    {
        return city;
    }
    bool operator<(const State& s) const
    {
        return city<s.city;
    }
    bool operator==(const State& s) const
    {
        return city==s.city;
    }
    vector<Action> actions() const
    {
        vector<Action> as;
        vector<City> neigh = romania.get_neigh(city);
        for(City c : neigh)
            as.push_back(Action(c));
        return as;
    }
    void apply(const Action& a)
    {
        city = City(a);
    }
    void print() const
    {
        cout << "STATE: " << c2s(city) << endl;
    }
};
 
 
class Problem
{
    State initial_s;
    State goal_s;
public:
    Problem(const State& is, const State& gs)
    {
        initial_s = is;
        goal_s = gs;
    }
    State initial_state() const
    {
        return initial_s;
    }
    State goal_state() const
    {
        return goal_s;
    }
    bool is_goal(const State& s) const
    {
        return s==goal_s;
    }
    vector<Action> actions(const State& s) const
    {
        return s.actions();
    }
    State result(const State& s, const Action& a) const
    {
        State res(s);
        res.apply(a);
        return res;
    }
    int action_cost(const State& s, const Action& a, const State& sprim) const
    {
        assert(sprim.get_city() == City(a));
        return romania.get_dist(s.get_city(), sprim.get_city());
    }
};
 
class Node
{
public:
    State state;
    int path_cost;
    Node(const State& s, int path_c)
    {
        state = s;
        path_cost = path_c;
    }
};
 
// Generates the children of `node`: one node per action available in its
// state, each carrying the parent's path cost plus the cost of that step.
vector<Node> expand(const Problem& problem, const Node& node)
{
    const State& parent = node.state;
    vector<Node> children;
    for(const Action& move : problem.actions(parent))
    {
        const State next = problem.result(parent, move);
        const int step = problem.action_cost(parent, move, next);
        children.emplace_back(next, node.path_cost + step);
    }
    return children;
}
 
// Breadth-first search from the problem's initial state.
// Returns the first goal node generated (the goal test is applied when a
// child is created, not when it is taken off the queue). When the frontier
// runs empty, prints "FAILURE" and returns a node with path_cost -1.
Node BFS(const Problem& problem)
{
    const Node root(problem.initial_state(), 0);
    if(problem.is_goal(root.state))
        return root;

    // States already generated; prevents queueing a city twice.
    set<State> reached;
    reached.insert(root.state);

    // FIFO queue of nodes waiting to be expanded.
    queue<Node> frontier;
    frontier.push(root);

    while(!frontier.empty())
    {
        const Node current = frontier.front();
        frontier.pop();
        for(const Node& child : expand(problem, current))
        {
            if(problem.is_goal(child.state))
                return child;
            // insert() reports through .second whether the state was new.
            if(reached.insert(child.state).second)
                frontier.push(child);
        }
    }
    cout << "FAILURE" << endl;
    return Node(State(), -1);
}
 
int main()
{
 
    Problem problem(State(City::ARAD), State(City::BUCHAREST));
    Node node = BFS(problem);
    cout << node.path_cost << endl;
 
}