1   /*
2    * SPDX-FileCopyrightText: none
3    * SPDX-License-Identifier: CC0-1.0
4    */
5   
6   package gov.nist.secauto.oscal.lib.model.util;
7   
import gov.nist.secauto.metaschema.core.metapath.DynamicContext;
import gov.nist.secauto.metaschema.core.metapath.StaticContext;
import gov.nist.secauto.metaschema.core.metapath.item.ISequence;
import gov.nist.secauto.metaschema.core.metapath.item.node.AbstractRecursionPreventingNodeItemVisitor;
import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyInstanceGroupedNodeItem;
import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyNodeItem;
import gov.nist.secauto.metaschema.core.metapath.item.node.IDefinitionNodeItem;
import gov.nist.secauto.metaschema.core.metapath.item.node.IFieldNodeItem;
import gov.nist.secauto.metaschema.core.metapath.item.node.IFlagNodeItem;
import gov.nist.secauto.metaschema.core.metapath.item.node.IModuleNodeItem;
import gov.nist.secauto.metaschema.core.metapath.item.node.INodeItemFactory;
import gov.nist.secauto.metaschema.core.model.IModule;
import gov.nist.secauto.metaschema.core.model.constraint.IAllowedValuesConstraint;
import gov.nist.secauto.metaschema.core.model.constraint.ILet;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import edu.umd.cs.findbugs.annotations.NonNull;
30  
31  public class AllowedValueCollectingNodeItemVisitor
32      extends AbstractRecursionPreventingNodeItemVisitor<DynamicContext, Void> {
33  
34    private final Map<IDefinitionNodeItem<?, ?>, NodeItemRecord> nodeItemAnalysis = new LinkedHashMap<>();
35  
36    public Collection<NodeItemRecord> getAllowedValueLocations() {
37      return nodeItemAnalysis.values();
38    }
39  
40    public void visit(@NonNull IModule module) {
41      DynamicContext context = new DynamicContext(
42          StaticContext.builder()
43              .defaultModelNamespace(module.getXmlNamespace())
44              .build());
45      context.disablePredicateEvaluation();
46  
47      visit(INodeItemFactory.instance().newModuleNodeItem(module), context);
48    }
49  
50    public void visit(@NonNull IModuleNodeItem module, @NonNull DynamicContext context) {
51  
52      visitMetaschema(module, context);
53    }
54  
55    private void handleAllowedValuesAtLocation(
56        @NonNull IDefinitionNodeItem<?, ?> itemLocation,
57        @NonNull DynamicContext context) {
58      itemLocation.getDefinition().getAllowedValuesConstraints().stream()
59          .forEachOrdered(allowedValues -> {
60            ISequence<?> result = allowedValues.getTarget().evaluate(itemLocation, context);
61            result.stream().forEachOrdered(target -> {
62              assert target != null;
63              handleAllowedValues(allowedValues, itemLocation, (IDefinitionNodeItem<?, ?>) target);
64            });
65          });
66    }
67  
68    private void handleAllowedValues(
69        @NonNull IAllowedValuesConstraint allowedValues,
70        @NonNull IDefinitionNodeItem<?, ?> location,
71        @NonNull IDefinitionNodeItem<?, ?> target) {
72      NodeItemRecord itemRecord = nodeItemAnalysis.get(target);
73      if (itemRecord == null) {
74        itemRecord = new NodeItemRecord(target);
75        nodeItemAnalysis.put(target, itemRecord);
76      }
77  
78      AllowedValuesRecord allowedValuesRecord = new AllowedValuesRecord(allowedValues, location, target);
79      itemRecord.addAllowedValues(allowedValuesRecord);
80    }
81  
82    @Override
83    public Void visitFlag(IFlagNodeItem item, DynamicContext context) {
84      assert context != null;
85      DynamicContext subContext = handleLetStatements(item, context);
86      handleAllowedValuesAtLocation(item, subContext);
87      return super.visitFlag(item, subContext);
88    }
89  
90    @Override
91    public Void visitField(IFieldNodeItem item, DynamicContext context) {
92      assert context != null;
93      DynamicContext subContext = handleLetStatements(item, context);
94      handleAllowedValuesAtLocation(item, subContext);
95      return super.visitField(item, subContext);
96    }
97  
98    @Override
99    public Void visitAssembly(IAssemblyNodeItem item, DynamicContext context) {
100     assert context != null;
101     DynamicContext subContext = handleLetStatements(item, context);
102     handleAllowedValuesAtLocation(item, subContext);
103     return super.visitAssembly(item, subContext);
104   }
105 
106   private DynamicContext handleLetStatements(IDefinitionNodeItem<?, ?> item, DynamicContext context) {
107     assert context != null;
108     DynamicContext subContext = context;
109     for (ILet let : item.getDefinition().getLetExpressions().values()) {
110       ISequence<?> result = let.getValueExpression().evaluate(item,
111           subContext).reusable();
112       subContext = subContext.bindVariableValue(let.getName(), result);
113     }
114     return subContext;
115   }
116 
117   @Override
118   public Void visitAssembly(IAssemblyInstanceGroupedNodeItem item, DynamicContext context) {
119     return visitAssembly((IAssemblyNodeItem) item, context);
120   }
121 
122   @Override
123   protected Void defaultResult() {
124     return null;
125   }
126 
127   public static final class NodeItemRecord {
128     @NonNull
129     private final IDefinitionNodeItem<?, ?> item;
130     @NonNull
131     private final List<AllowedValuesRecord> allowedValues = new LinkedList<>();
132 
133     private NodeItemRecord(@NonNull IDefinitionNodeItem<?, ?> item) {
134       this.item = item;
135     }
136 
137     @NonNull
138     public IDefinitionNodeItem<?, ?> getItem() {
139       return item;
140     }
141 
142     @NonNull
143     public List<AllowedValuesRecord> getAllowedValues() {
144       return allowedValues;
145     }
146 
147     public void addAllowedValues(@NonNull AllowedValuesRecord record) {
148       this.allowedValues.add(record);
149     }
150   }
151 
152   public static final class AllowedValuesRecord {
153     @NonNull
154     private final IAllowedValuesConstraint allowedValues;
155     @NonNull
156     private final IDefinitionNodeItem<?, ?> location;
157     @NonNull
158     private final IDefinitionNodeItem<?, ?> target;
159 
160     public AllowedValuesRecord(
161         @NonNull IAllowedValuesConstraint allowedValues,
162         @NonNull IDefinitionNodeItem<?, ?> location,
163         @NonNull IDefinitionNodeItem<?, ?> target) {
164       this.allowedValues = allowedValues;
165       this.location = location;
166       this.target = target;
167     }
168 
169     @NonNull
170     public IAllowedValuesConstraint getAllowedValues() {
171       return allowedValues;
172     }
173 
174     @NonNull
175     public IDefinitionNodeItem<?, ?> getLocation() {
176       return location;
177     }
178 
179     @NonNull
180     public IDefinitionNodeItem<?, ?> getTarget() {
181       return target;
182     }
183   }
184 }